Skip to content

Commit b53bf3b

Browse files
Merge pull request #2733 from AI-Hypercomputer:hengtaoguo-kvcache
PiperOrigin-RevId: 836754833
2 parents 6191433 + 84b2a1d commit b53bf3b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/MaxText/layers/gemma3.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def __init__(
382382
matmul_precision=self.config.matmul_precision,
383383
rngs=self.rngs,
384384
)
385-
self.Dropout_0 = nnx.Dropout(rate=self.config.dropout_rate)
385+
self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
386386
self.Dense_1 = DenseGeneral(
387387
in_features_shape=self.config.intermediate_size_for_vit,
388388
out_features_shape=self.config.hidden_size_for_vit,
@@ -452,12 +452,12 @@ def __init__(
452452
config=self.config,
453453
rngs=self.rngs,
454454
)
455-
self.Dropout_0 = nnx.Dropout(self.config.dropout_rate, rngs=self.rngs)
455+
self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
456456

457457
def __call__(self, x: jax.Array, deterministic: bool = False) -> jax.Array:
458458
y = self.LayerNorm_0(x)
459459

460-
y = self.MultiHeadDotProductAttention_0(inputs_q=y, inputs_kv=y, deterministic=deterministic)
460+
y, _ = self.MultiHeadDotProductAttention_0(inputs_q=y, inputs_kv=y, deterministic=deterministic)
461461
y = self.Dropout_0(y, deterministic=deterministic)
462462
x = x + y
463463

@@ -634,7 +634,7 @@ def __init__(
634634
width=self.config.hidden_size_for_vit,
635635
dtype=self.config.dtype_mm,
636636
)
637-
self.Dropout_0 = nnx.Dropout(self.config.dropout_rate, rngs=self.rngs)
637+
self.Dropout_0 = Dropout(rate=self.config.dropout_rate, rngs=self.rngs)
638638
self.Transformer = Encoder(
639639
config=self.config,
640640
mesh=self.mesh,

0 commit comments

Comments
 (0)