Skip to content

Commit 84b2a1d

Browse files
committed
Update attention for Gemma3 vit
1 parent 238a410 commit 84b2a1d

File tree

1 file changed

+4 -4 lines changed

1 file changed

+4 -4 lines changed

src/MaxText/layers/gemma3.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -382,7 +382,7 @@ def __init__(
382 382   matmul_precision=self.config.matmul_precision,
383 383   rngs=self.rngs,
384 384   )
385     - self.Dropout_0 = nnx.Dropout(rate=self.config.dropout_rate)
    385 + self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
386 386   self.Dense_1 = DenseGeneral(
387 387   in_features_shape=self.config.intermediate_size_for_vit,
388 388   out_features_shape=self.config.hidden_size_for_vit,
@@ -452,12 +452,12 @@ def __init__(
452 452   config=self.config,
453 453   rngs=self.rngs,
454 454   )
455     - self.Dropout_0 = nnx.Dropout(self.config.dropout_rate, rngs=self.rngs)
    455 + self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
456 456
457 457   def __call__(self, x: jax.Array, deterministic: bool = False) -> jax.Array:
458 458   y = self.LayerNorm_0(x)
459 459
460     - y = self.MultiHeadDotProductAttention_0(inputs_q=y, inputs_kv=y, deterministic=deterministic)
    460 + y, _ = self.MultiHeadDotProductAttention_0(inputs_q=y, inputs_kv=y, deterministic=deterministic)
461 461   y = self.Dropout_0(y, deterministic=deterministic)
462 462   x = x + y
463 463
@@ -634,7 +634,7 @@ def __init__(
634 634   width=self.config.hidden_size_for_vit,
635 635   dtype=self.config.dtype_mm,
636 636   )
637     - self.Dropout_0 = nnx.Dropout(self.config.dropout_rate, rngs=self.rngs)
    637 + self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
638 638   self.Transformer = Encoder(
639 639   config=self.config,
640 640   mesh=self.mesh,

0 commit comments

Comments (0)