File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed
Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -304,7 +304,7 @@ def update_params(
304304 grad_clip = hyperparameters ['grad_clip' ]
305305 else :
306306 grad_clip = None
307- dropout_rate = hyperparameters . dropout_rate
307+ dropout_rate = hyperparameters [ ' dropout_rate' ]
308308
309309 # Create shardings for each argument
310310 mesh = jax .sharding .Mesh (jax .devices (), ('batch' ))
Original file line number Diff line number Diff line change @@ -304,7 +304,7 @@ def update_params(
304304 grad_clip = hyperparameters ['grad_clip' ]
305305 else :
306306 grad_clip = None
307- dropout_rate = hyperparameters . dropout_rate
307+ dropout_rate = hyperparameters [ ' dropout_rate' ]
308308
309309 # Create shardings for each argument
310310 mesh = jax .sharding .Mesh (jax .devices (), ('batch' ))
You can’t perform that action at this time.
0 commit comments