Commit 161c739

Slight cleanup.

1 parent 7e64eb2

3 files changed (+30, -46 lines)

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
+.pytest_cache

 # C extensions
 *.so
@@ -20,9 +21,11 @@ var/
 .installed.cfg
 *.egg
 *.whl
+.eggs

 # IDE-related
 .idea/
+.vscode/

 # Dev
 venv

Makefile

Lines changed: 0 additions & 9 deletions
This file was deleted.

lame_test.py

Lines changed: 27 additions & 37 deletions
@@ -1,26 +1,33 @@
-import sys, glob
+import glob
+import os
+import sys
+
+os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
+
 if glob.glob('build/lib.linux-*'):
     sys.path.append(glob.glob('build/lib.linux-*')[0])

-from functools import partial
-from functools import reduce
+from functools import partial, reduce

-import jaxlib.mlir.ir
 import jax
 import jax.numpy as jnp
-import jax._src.test_util as jtu
+import jaxlib.mlir.ir
+from einops import rearrange
 from jax import core, dtypes
-from jax.interpreters import xla
-from jax.lib import xla_client
-from jax.interpreters import mlir
+from jax.experimental import mesh_utils
+from jax.interpreters import mlir, xla
 from jax.interpreters.mlir import ir
+from jax.lib import xla_client
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from jax.sharding import PositionalSharding
 from jaxlib.hlo_helpers import custom_call

-# from flash_attn_jax.flash import flash_mha_fwd, flash_mha_bwd
 from flash_attn_jax import flash_mha

 if __name__ == '__main__':
     import time
+
     import numpy as np

     @jax.jit
@@ -50,34 +57,17 @@ def pretty(tensor):
     def fwd(q,k,v):
         return flash_mha(q,k,v)

-    # print(fwd.lower(q,k,v).as_text())
-
-    from jax.sharding import PositionalSharding
-    from einops import rearrange
-
-    # sharding = PositionalSharding(jax.devices())
-    devices = jax.devices()
-    # devices = [*jax.devices(), *jax.devices(backend='cpu')]
-    n_device = len(devices)
-    sharding = PositionalSharding(devices).reshape(1,-1,1,1)#.replicate()
-
-
-    # from jax.experimental import mesh_utils
-    # from jax.sharding import PartitionSpec as P, Mesh
-    # from jax.sharding import NamedSharding
-    # devices = np.array(jax.devices()) #mesh_utils.create_device_mesh((1,))
-    # mesh = Mesh(devices, axis_names=('x',))
-    # sharding = NamedSharding(mesh, P(None,None,'x',None))

-    # print(mesh)
+    # devices = jax.devices(backend='cpu')
+    # n_device = len(devices)
+    # sharding = PositionalSharding(devices).reshape(-1,1,1,1)#.replicate()

-    o_ref = fwd(q,k,v)
+    devices = jax.devices(backend='gpu')
+    with Mesh(devices, axis_names=('x',)) as mesh:
+        sharding = NamedSharding(mesh, P(None,None,'x',None))
+        q = jax.device_put(q, sharding)
+        k = jax.device_put(k, sharding)
+        v = jax.device_put(v, sharding)
+        # jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
+        print(fwd.lower(q,k,v).compile().as_text())

-    q = jax.device_put(q, sharding)
-    k = jax.device_put(k, sharding)
-    v = jax.device_put(v, sharding)
-    jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
-    print(fwd.lower(q,k,v).compile().as_text())
-    o = fwd(q,k,v)
-    jax.debug.visualize_array_sharding(rearrange(o, 'n l h d -> n (l h d)'))
-    print((o - o_ref).std())
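For orientation, the sharding pattern the test moves to, a one-axis Mesh with a NamedSharding of P(None, None, 'x', None) that splits (n, l, h, d) inputs across devices along the heads axis, can be reproduced without the CUDA extension. The sketch below is not from this repo: it reuses the commit's XLA_FLAGS trick to emulate two host devices, and substitutes a hypothetical plain-attention ref_mha for flash_mha so it runs on CPU.

# Illustrative sketch, not part of this commit; only the sharding
# pattern mirrors the new test code, ref_mha is a hypothetical stand-in.
import os
# Emulate 2 CPU devices; must be set before importing jax.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

@jax.jit
def ref_mha(q, k, v):
    # Plain softmax attention over (n, l, h, d) arrays, standing in
    # for flash_mha so the sketch runs without the extension.
    s = jnp.einsum('nlhd,nLhd->nhlL', q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum('nhlL,nLhd->nlhd', jax.nn.softmax(s, axis=-1), v)

n, l, h, d = 1, 8, 4, 16  # head count divisible by device count
q = jnp.ones([n, l, h, d], dtype=jnp.float32)
k = jnp.ones([n, l, h, d], dtype=jnp.float32)
v = jnp.ones([n, l, h, d], dtype=jnp.float32)

devices = jax.devices()  # the two emulated host devices
with Mesh(devices, axis_names=('x',)) as mesh:
    # P(None, None, 'x', None): replicate batch and length, split heads over 'x'.
    sharding = NamedSharding(mesh, P(None, None, 'x', None))
    q = jax.device_put(q, sharding)
    k = jax.device_put(k, sharding)
    v = jax.device_put(v, sharding)
    print(ref_mha.lower(q, k, v).compile().as_text())  # inspect partitioned HLO

Printing the compiled HLO, as the test now does, is a quick way to check whether the SPMD partitioner kept the head dimension sharded or fell back to gathering the inputs onto every device.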

0 commit comments