Skip to content

Commit 84b78f6

Browse files
committed
fix(jax): support jax 0.7+
1 parent 8cda2f7 commit 84b78f6

File tree

13 files changed: +1289 additions, −914 deletions

src/equimo/experimental/text.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Callable, Optional, Sequence
2+
from typing import Callable, Optional, Sequence, Tuple
33
from urllib.parse import urlparse
44

55
import equinox as eqx
@@ -136,7 +136,7 @@ class Transformer(eqx.Module):
136136
blocks: A list of `AttentionBlock` instances forming the transformer stack.
137137
"""
138138

139-
blocks: list[AttentionBlock]
139+
blocks: Tuple[AttentionBlock, ...]
140140

141141
def __init__(
142142
self,
@@ -164,7 +164,7 @@ def __init__(
164164

165165
act_layer = get_act(act_layer)
166166

167-
self.blocks = [
167+
self.blocks = tuple(
168168
AttentionBlock(
169169
dim=dim,
170170
num_heads=num_heads,
@@ -173,7 +173,7 @@ def __init__(
173173
key=keys[i],
174174
)
175175
for i in range(depth)
176-
]
176+
)
177177

178178
def __call__(
179179
self,

src/equimo/layers/attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,7 +1671,7 @@ class RFAttention(eqx.Module):
16711671
eps: float = eqx.field(static=True)
16721672

16731673
qkv: eqx.nn.Conv2d
1674-
aggreg: list[eqx.nn.Conv2d]
1674+
aggreg: Tuple[eqx.nn.Conv2d, ...]
16751675
proj: SingleConvBlock
16761676

16771677
def __init__(
@@ -1709,7 +1709,7 @@ def __init__(
17091709
use_bias=use_bias,
17101710
key=key_qkv,
17111711
)
1712-
self.aggreg = [
1712+
self.aggreg = tuple(
17131713
eqx.nn.Conv2d(
17141714
in_channels=3 * total_dim,
17151715
out_channels=3 * total_dim,
@@ -1720,7 +1720,7 @@ def __init__(
17201720
use_bias=use_bias,
17211721
)
17221722
for scale in scales
1723-
]
1723+
)
17241724
# TODO: test different normalizations
17251725
self.proj = SingleConvBlock(
17261726
in_channels=self.total_dim,

src/equimo/layers/convolution.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ class C2f(eqx.Module):
409409

410410
conv1: SingleConvBlock
411411
conv2: SingleConvBlock
412-
blocks: list[ConvBottleneck]
412+
blocks: Tuple[ConvBottleneck, ...]
413413

414414
def __init__(
415415
self,
@@ -445,7 +445,7 @@ def __init__(
445445
key=key_conv2,
446446
)
447447

448-
self.blocks = [
448+
self.blocks = tuple(
449449
ConvBottleneck(
450450
in_channels=self.hidden_channels,
451451
out_channels=self.hidden_channels,
@@ -456,7 +456,7 @@ def __init__(
456456
key=key_blocks[i],
457457
)
458458
for i in range(n)
459-
]
459+
)
460460

461461
def __call__(
462462
self,
@@ -589,7 +589,7 @@ class C3k2(eqx.Module):
589589

590590
conv1: SingleConvBlock
591591
conv2: SingleConvBlock
592-
blocks: list[ConvBottleneck] | list[C3k]
592+
blocks: Tuple[ConvBottleneck, ...] | Tuple[C3k, ...]
593593

594594
def __init__(
595595
self,
@@ -627,7 +627,7 @@ def __init__(
627627
)
628628

629629
if c3k:
630-
self.blocks = [
630+
self.blocks = tuple(
631631
C3k(
632632
in_channels=self.hidden_channels,
633633
out_channels=self.hidden_channels,
@@ -637,9 +637,9 @@ def __init__(
637637
key=key_blocks[i],
638638
)
639639
for i in range(n)
640-
]
640+
)
641641
else:
642-
self.blocks = [
642+
self.blocks = tuple(
643643
ConvBottleneck(
644644
in_channels=self.hidden_channels,
645645
out_channels=self.hidden_channels,
@@ -648,7 +648,7 @@ def __init__(
648648
key=key_blocks[i],
649649
)
650650
for i in range(n)
651-
]
651+
)
652652

653653
def __call__(
654654
self,
@@ -1078,12 +1078,12 @@ class GenericGhostModule(eqx.Module):
10781078
cheap_operation: eqx.nn.Conv2d
10791079

10801080
# Training
1081-
primary_rpr_conv: list[eqx.nn.Conv2d]
1081+
primary_rpr_conv: Tuple[eqx.nn.Conv2d, ...]
10821082
primary_rpr_scale: eqx.nn.Conv2d | eqx.nn.Identity
10831083
primary_shared_norm: eqx.nn.GroupNorm
10841084
primary_activation: Callable
10851085

1086-
cheap_rpr_conv: list[eqx.nn.Conv2d]
1086+
cheap_rpr_conv: Tuple[eqx.nn.Conv2d, ...]
10871087
cheap_rpr_scale: eqx.nn.Conv2d | eqx.nn.Identity
10881088
cheap_shared_norm: eqx.nn.GroupNorm
10891089
cheap_activation: Callable
@@ -1156,7 +1156,7 @@ def __init__(
11561156

11571157
# Primary training branches
11581158
init_num_groups = nearest_power_of_2_divisor(init_channels, 32)
1159-
self.primary_rpr_conv = [
1159+
self.primary_rpr_conv = tuple(
11601160
eqx.nn.Conv2d(
11611161
in_channels=in_channels,
11621162
out_channels=init_channels,
@@ -1167,7 +1167,7 @@ def __init__(
11671167
key=key_ps[i],
11681168
)
11691169
for i in range(num_conv_branches)
1170-
]
1170+
)
11711171
self.primary_rpr_scale = (
11721172
eqx.nn.Conv2d(
11731173
in_channels=in_channels,
@@ -1186,7 +1186,7 @@ def __init__(
11861186

11871187
# Cheap training branches (depthwise)
11881188
newchannels_num_groups = nearest_power_of_2_divisor(new_channels, 32)
1189-
self.cheap_rpr_conv = [
1189+
self.cheap_rpr_conv = tuple(
11901190
eqx.nn.Conv2d(
11911191
in_channels=init_channels,
11921192
out_channels=new_channels,
@@ -1198,7 +1198,7 @@ def __init__(
11981198
key=key_cs[i],
11991199
)
12001200
for i in range(self.num_conv_branches)
1201-
]
1201+
)
12021202
self.cheap_rpr_scale = (
12031203
eqx.nn.Conv2d(
12041204
in_channels=init_channels,
@@ -1344,7 +1344,7 @@ class GhostBottleneck(eqx.Module):
13441344
ghost2: "GenericGhostModule"
13451345

13461346
dw_conv: eqx.nn.Conv2d | eqx.nn.Identity
1347-
dw_rpr_conv: list[eqx.nn.Conv2d] # depthwise conv branches (no bias)
1347+
dw_rpr_conv: Tuple[eqx.nn.Conv2d, ...] # depthwise conv branches (no bias)
13481348
dw_rpr_scale: eqx.nn.Conv2d | eqx.nn.Identity # optional 1x1 depthwise (no bias)
13491349
dw_shared_norm: eqx.nn.GroupNorm | eqx.nn.Identity
13501350

@@ -1393,7 +1393,7 @@ def __init__(
13931393
# Depthwise stage (only if stride > 1)
13941394
if stride > 1:
13951395
# Training-time branches (depthwise, no bias); no activation; shared GN after sum
1396-
self.dw_rpr_conv = [
1396+
self.dw_rpr_conv = tuple(
13971397
eqx.nn.Conv2d(
13981398
in_channels=mid_channels,
13991399
out_channels=mid_channels,
@@ -1405,7 +1405,7 @@ def __init__(
14051405
key=k_dw_list[i],
14061406
)
14071407
for i in range(3)
1408-
]
1408+
)
14091409
# Optional scale branch (1x1, depthwise, stride=stride)
14101410
self.dw_rpr_scale = (
14111411
eqx.nn.Conv2d(

src/equimo/layers/sharing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import List, Optional
2+
from typing import List, Optional, Tuple
33

44
import equinox as eqx
55
import jax
@@ -72,8 +72,8 @@ class LayerSharing(eqx.Module):
7272

7373
repeat: int = eqx.field(static=True)
7474

75-
loras: List[LoRA]
76-
dropouts: List[eqx.nn.Dropout]
75+
loras: Tuple[LoRA, ...]
76+
dropouts: Tuple[eqx.nn.Dropout, ...]
7777
f: eqx.Module
7878

7979
def __init__(
@@ -93,8 +93,8 @@ def __init__(
9393
keys = jr.split(key, repeat)
9494
self.repeat = repeat
9595

96-
self.dropouts = [eqx.nn.Dropout(drop_rate) for i in range(self.repeat)]
97-
self.loras = [
96+
self.dropouts = tuple(eqx.nn.Dropout(drop_rate) for i in range(self.repeat))
97+
self.loras = tuple(
9898
LoRA(
9999
in_features=dim,
100100
out_features=dim,
@@ -103,7 +103,7 @@ def __init__(
103103
key=keys[i],
104104
)
105105
for i in range(self.repeat)
106-
]
106+
)
107107

108108
self.f = f
109109

src/equimo/models/fastervit.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Literal, Optional
1+
from typing import Callable, List, Literal, Optional, Tuple
22

33
import equinox as eqx
44
import jax
@@ -157,7 +157,7 @@ class BlockChunk(eqx.Module):
157157
window_size: bool = eqx.field(static=True)
158158
do_gt: bool = eqx.field(static=True)
159159

160-
blocks: List[eqx.Module]
160+
blocks: Tuple[eqx.Module, ...]
161161
downsample: eqx.Module
162162
global_tokenizer: Optional[TokenInitializer]
163163

@@ -195,7 +195,7 @@ def __init__(
195195
k for k, v in kwargs.items() if isinstance(v, list) and len(v) == depth
196196
]
197197

198-
self.blocks = []
198+
blocks = []
199199
for i in range(depth):
200200
config = kwargs | {k: kwargs[k][i] for k in keys_to_spread}
201201

@@ -210,14 +210,15 @@ def __init__(
210210
}
211211

212212
wrapper = LayerSharingWithCT if self.is_hat else LayerSharing
213-
self.blocks.append(
213+
blocks.append(
214214
wrapper(
215215
dim=kwargs.get("dim"),
216216
f=block(**config, key=block_subkeys[i]),
217217
repeat=repeat,
218218
key=block_subkeys[i],
219219
),
220220
)
221+
self.blocks = tuple(blocks)
221222

222223
self.downsample = downsampler(dim=kwargs.get("dim"), key=key_ds)
223224

@@ -331,7 +332,7 @@ class FasterViT(eqx.Module):
331332
"""
332333

333334
patch_embed: ConvPatchEmbed
334-
blocks: List[eqx.Module]
335+
blocks: Tuple[eqx.Module, ...]
335336
norm: eqx.Module
336337
head: eqx.Module
337338

@@ -398,7 +399,7 @@ def __init__(
398399
hat = to_list(hat, n_chunks)
399400
attn_layer = to_list(attn_layer, n_chunks)
400401
window_size = to_list(window_size, n_chunks)
401-
self.blocks = [
402+
self.blocks = tuple(
402403
BlockChunk(
403404
block=ConvBlock if i < 2 else HATBlock,
404405
repeat=repeat,
@@ -427,7 +428,7 @@ def __init__(
427428
key=block_subkeys[i],
428429
)
429430
for i, depth in enumerate(depths)
430-
]
431+
)
431432

432433
num_features = int(dim * 2 ** (len(depths) - 1))
433434
self.norm = norm_layer(num_features)

src/equimo/models/lowformer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class BlockChunk(eqx.Module):
1717
residuals: list[bool] = eqx.field(static=True)
18-
blocks: list[DSConv | MBConv | LowFormerBlock]
18+
blocks: Tuple[DSConv | MBConv | LowFormerBlock, ...]
1919

2020
def __init__(
2121
self,
@@ -123,7 +123,7 @@ def __init__(
123123
)
124124
residuals.append(False)
125125

126-
self.blocks = blocks
126+
self.blocks = tuple(blocks)
127127
self.residuals = residuals
128128

129129
def __call__(
@@ -146,7 +146,7 @@ def __call__(
146146

147147
class LowFormer(eqx.Module):
148148
input_stem: eqx.nn.Sequential
149-
blocks: list[BlockChunk]
149+
blocks: Tuple[BlockChunk, ...]
150150
head: eqx.nn.Linear | eqx.nn.Identity
151151

152152
def __init__(
@@ -220,7 +220,7 @@ def __init__(
220220
]
221221
)
222222

223-
self.blocks = [
223+
self.blocks = tuple(
224224
BlockChunk(
225225
in_channels=widths[i - 1] if i > 0 else width_stem,
226226
out_channels=widths[i],
@@ -241,7 +241,7 @@ def __init__(
241241
for i, (depth, att_stride, block_type, key_block) in enumerate(
242242
zip(depths, att_strides, block_types, key_blocks)
243243
)
244-
]
244+
)
245245

246246
self.head = (
247247
eqx.nn.Linear(

src/equimo/models/mlla.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, Tuple
22

33
import equinox as eqx
44
import jax
@@ -38,7 +38,7 @@ class Mlla(eqx.Module):
3838

3939
patch_embed: eqx.Module
4040
pos_drop: eqx.Module
41-
blocks: List[eqx.Module]
41+
blocks: Tuple[eqx.Module, ...]
4242
head: eqx.Module
4343

4444
def __init__(
@@ -52,7 +52,7 @@ def __init__(
5252
patch_size: int = 4,
5353
depths: List[int] = [2, 2, 6, 2],
5454
num_heads: List[int] = [3, 6, 12, 24],
55-
attentions_layers: List[eqx.Module] | eqx.Module = LinearAttention,
55+
attentions_layers: Tuple[eqx.Module, ...] | eqx.Module = LinearAttention,
5656
drop_rate: float = 0.0,
5757
drop_path_rate: float = 0.0,
5858
drop_path_uniform: bool = False,
@@ -101,8 +101,8 @@ def __init__(
101101
)
102102

103103
num_heads = to_list(num_heads, n_chunks)
104-
attentions_layers = to_list(attentions_layers, n_chunks)
105-
self.blocks = [
104+
attentions_layers = tuple(to_list(attentions_layers, n_chunks))
105+
self.blocks = tuple(
106106
BlockChunk(
107107
block=MllaBlock,
108108
repeat=repeat,
@@ -124,7 +124,7 @@ def __init__(
124124
key=block_subkeys[i],
125125
)
126126
for i, depth in enumerate(depths)
127-
]
127+
)
128128

129129
self.norm = eqx.nn.LayerNorm(self.num_features)
130130
self.head = (

Comments (0): this commit has no comments.