@@ -1,26 +1,33 @@
-import sys, glob
+import glob
+import os
+import sys
+
+os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
+
 if glob.glob('build/lib.linux-*'):
     sys.path.append(glob.glob('build/lib.linux-*')[0])
 
-from functools import partial
-from functools import reduce
+from functools import partial, reduce
 
-import jaxlib.mlir.ir
 import jax
 import jax.numpy as jnp
-import jax._src.test_util as jtu
+import jaxlib.mlir.ir
+from einops import rearrange
 from jax import core, dtypes
-from jax.interpreters import xla
-from jax.lib import xla_client
-from jax.interpreters import mlir
+from jax.experimental import mesh_utils
+from jax.interpreters import mlir, xla
 from jax.interpreters.mlir import ir
+from jax.lib import xla_client
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from jax.sharding import PositionalSharding
 from jaxlib.hlo_helpers import custom_call
 
-# from flash_attn_jax.flash import flash_mha_fwd, flash_mha_bwd
 from flash_attn_jax import flash_mha
 
 if __name__ == '__main__':
     import time
+
     import numpy as np
 
     @jax.jit
@@ -50,34 +57,17 @@ def pretty(tensor):
     def fwd(q,k,v):
         return flash_mha(q,k,v)
 
-    # print(fwd.lower(q,k,v).as_text())
-
-    from jax.sharding import PositionalSharding
-    from einops import rearrange
-
-    # sharding = PositionalSharding(jax.devices())
-    devices = jax.devices()
-    # devices = [*jax.devices(), *jax.devices(backend='cpu')]
-    n_device = len(devices)
-    sharding = PositionalSharding(devices).reshape(1,-1,1,1)#.replicate()
-
-
-    # from jax.experimental import mesh_utils
-    # from jax.sharding import PartitionSpec as P, Mesh
-    # from jax.sharding import NamedSharding
-    # devices = np.array(jax.devices()) #mesh_utils.create_device_mesh((1,))
-    # mesh = Mesh(devices, axis_names=('x',))
-    # sharding = NamedSharding(mesh, P(None,None,'x',None))
 
-    # print(mesh)
+    # devices = jax.devices(backend='cpu')
+    # n_device = len(devices)
+    # sharding = PositionalSharding(devices).reshape(-1,1,1,1)#.replicate()
 
-    o_ref = fwd(q,k,v)
+    devices = jax.devices(backend='gpu')
+    with Mesh(devices, axis_names=('x',)) as mesh:
+        sharding = NamedSharding(mesh, P(None,None,'x',None))
+        q = jax.device_put(q, sharding)
+        k = jax.device_put(k, sharding)
+        v = jax.device_put(v, sharding)
+        # jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
+        print(fwd.lower(q,k,v).compile().as_text())
 
-    q = jax.device_put(q, sharding)
-    k = jax.device_put(k, sharding)
-    v = jax.device_put(v, sharding)
-    jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
-    print(fwd.lower(q,k,v).compile().as_text())
-    o = fwd(q,k,v)
-    jax.debug.visualize_array_sharding(rearrange(o, 'n l h d -> n (l h d)'))
-    print((o - o_ref).std())
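
For context (not part of the commit, and illustrative only): the updated test switches from PositionalSharding to JAX's Mesh/NamedSharding API, which names a device axis and maps one array axis onto it. A minimal standalone sketch of the same pattern, assuming two emulated CPU devices and a trivial jitted function in place of flash_mha:

import os
# Emulate two CPU devices so the sketch runs anywhere; must be set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

devices = jax.devices(backend='cpu')
with Mesh(devices, axis_names=('x',)) as mesh:
    # Shard the head axis (axis 2 of the test's [n, l, h, d] layout) across 'x';
    # None leaves the other axes replicated.
    sharding = NamedSharding(mesh, P(None, None, 'x', None))
    q = jax.device_put(jnp.ones((2, 16, 4, 8)), sharding)
    o = jax.jit(lambda t: t * 2.0)(q)
    print(q.sharding)  # PartitionSpec(None, None, 'x', None) over the 'x' mesh axis
    print(o.sharding)  # for an elementwise op, the output typically keeps the same sharding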