
Commit 0123e0d

Clean up imports in flash.py, flash_sharding.py, and ring_attention.py.
1 parent 7cf539e

4 files changed: +7 −27 lines

scripts/install-cuda-linux.sh

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 #!/bin/bash
+# Install CUDA on manylinux docker image.
 set -eux

 VER=${1:-12.8}

src/flash_attn_jax/flash.py

Lines changed: 3 additions & 8 deletions
@@ -8,16 +8,8 @@
 from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax.interpreters import xla
-from jax.interpreters.mlir import ir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-from jax.experimental.custom_partitioning import custom_partitioning
 from jax.extend.core import Primitive

-from jax.sharding import PartitionSpec as P
-from jax.sharding import Mesh
-from jax.sharding import NamedSharding
-
 from einops import rearrange
 import einops
 import math

@@ -206,11 +198,14 @@ def custom_vjp(cls, nondiff_argnums=()):
 # bwd.
 @partial(custom_vjp, nondiff_argnums=(3,))
 class _flash_mha_vjp:
+    @staticmethod
     def base(q,k,v,config):
         return _flash_mha_fwd(q,k,v, **config)[0]
+    @staticmethod
     def fwd(q,k,v,config):
         out, lse = _flash_mha_fwd(q,k,v, **config)
         return out, (q,k,v,out,lse)
+    @staticmethod
     def bwd(config, pack, dout):
         (q,k,v,out,lse) = pack
         dq, dk, dv = _flash_mha_bwd(dout, q, k, v, out, lse, **config)
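For context, the class above is a thin bundle around JAX's custom-VJP machinery: base is the primal, fwd saves residuals, and bwd consumes them, with config marked non-differentiable via nondiff_argnums=(3,) so that bwd receives it as its first argument. Below is a minimal, self-contained sketch of that wiring using the standard jax.custom_vjp API with a plain softmax attention stand-in; the names mha, mha_fwd, mha_bwd, the config dict contents, and the zero-gradient backward pass are illustrative only, not the repository's flash kernels.

from functools import partial
import jax
import jax.numpy as jnp

def _reference_attn(q, k, v, config):
    # Plain softmax attention, used here only as a stand-in for the flash kernel.
    scores = jnp.einsum('...qd,...kd->...qk', q, k) * config['softmax_scale']
    return jnp.einsum('...qk,...kd->...qd', jax.nn.softmax(scores, axis=-1), v)

@partial(jax.custom_vjp, nondiff_argnums=(3,))
def mha(q, k, v, config):
    return _reference_attn(q, k, v, config)

def mha_fwd(q, k, v, config):
    # Forward pass: return the output plus residuals for the backward pass.
    out = _reference_attn(q, k, v, config)
    return out, (q, k, v)

def mha_bwd(config, residuals, dout):
    # Because config is in nondiff_argnums, it is passed first here.
    # A real backward pass would compute (dq, dk, dv); zeros keep the sketch short.
    q, k, v = residuals
    return tuple(jnp.zeros_like(x) for x in (q, k, v))

mha.defvjp(mha_fwd, mha_bwd)

With this wiring, e.g. jax.grad(lambda q: mha(q, k, v, {'softmax_scale': 1.0}).sum())(q) routes through mha_bwd. The @staticmethod decorators added in this commit make explicit that base, fwd, and bwd are free functions grouped on the class rather than instance methods.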

src/flash_attn_jax/flash_sharding.py

Lines changed: 3 additions & 10 deletions
@@ -5,14 +5,7 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
-from jax.core import ShapedArray
-from jax.interpreters import batching
-from jax.interpreters import mlir
-from jax.interpreters import xla
 from jax.interpreters.mlir import ir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
 from jax.experimental.custom_partitioning import custom_partitioning, SdyShardingRule, ArrayMapping, CompoundFactor

 from jax.sharding import PartitionSpec as P

@@ -35,9 +28,9 @@ def is_replicated(sharding):
         return sharding.is_fully_replicated
     raise ValueError(f"Unsupported sharding type: {type(sharding)}")

-def partition_fwd(softmax_scale, is_causal, window_size,
-                  mesh: Mesh,
-                  arg_shapes: List[jax.ShapeDtypeStruct],
+def partition_fwd(softmax_scale, is_causal, window_size,
+                  mesh: Mesh,
+                  arg_shapes: List[jax.ShapeDtypeStruct],
                   result_shape: List[jax.ShapeDtypeStruct]):
     result_shardings = [x.sharding for x in result_shape],
     arg_shardings = [x.sharding for x in arg_shapes]
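For context on the partition_fwd signature above: with jax.experimental.custom_partitioning, the partition callback receives any static arguments first (here softmax_scale, is_causal, window_size), followed by the mesh, the operand shapes with their shardings, and the result shapes, and returns the mesh, a per-shard lowering function, and the chosen result and argument shardings. The sketch below shows that callback protocol on a trivial element-wise op with no static arguments; the names sharded_scale, infer_sharding, and partition are illustrative, not the repository's implementation, and the Shardy-specific pieces imported in the diff (SdyShardingRule, ArrayMapping, CompoundFactor) are not covered here.

import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def sharded_scale(x):
    return 2.0 * x

def infer_sharding(mesh, arg_shapes, result_shape):
    # Propagate the input's sharding to the output.
    return jax.tree_util.tree_map(lambda _: arg_shapes[0].sharding, result_shape)

def partition(mesh, arg_shapes, result_shape):
    arg_shardings = tuple(a.sharding for a in arg_shapes)
    result_shardings = jax.tree_util.tree_map(lambda r: r.sharding, result_shape)

    def lower_fn(x):
        # Runs on the local shard; an element-wise op needs no collectives.
        return 2.0 * x

    return mesh, lower_fn, result_shardings, arg_shardings

sharded_scale.def_partition(
    infer_sharding_from_operands=infer_sharding,
    partition=partition,
)

Under the newer Shardy partitioner, def_partition may also accept a sharding_rule, which is presumably what the SdyShardingRule import in this file is for.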

src/flash_attn_jax/ring_attention.py

Lines changed: 0 additions & 9 deletions
@@ -3,15 +3,6 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
-from jax.core import ShapedArray
-from jax.interpreters import batching
-from jax.interpreters import mlir
-from jax.interpreters import xla
-from jax.interpreters.mlir import ir
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
-from jax.experimental.custom_partitioning import custom_partitioning

 from einops import rearrange
 import math
