Skip to content

Commit c251dec

Browse files
Update neural_network_future.py
Syntax errors...
1 parent 0730fad commit c251dec

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

cerebros/neuralnetworkfuture/neural_network_future.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def compile_neural_network(self):
341341
weight_decay=0.004, # Add weight decay parameter
342342
gradient_accumulation_steps=self.gradient_accumulation_steps
343343
),
344-
jit_compile=Tree) # jit_compile)
344+
jit_compile=True) # jit_compile)
345345
elif self.gradient_accumulation_steps == 1:
346346
self.materialized_neural_network.compile(
347347
loss=self.loss,
@@ -350,7 +350,7 @@ def compile_neural_network(self):
350350
learning_rate=self.learning_rate,
351351
weight_decay=0.004, # Add weight decay parameter
352352
),
353-
jit_compile=jit_compile)
353+
jit_compile=True) # jit_compile=jit_compile)
354354
else:
355355
raise ValueError("gradient_accumulation_steps must be an int >= 0. You set it as {self.gradient_accumulation_steps} type {type(self.gradient_accumulation_steps)}")
356356

0 commit comments

Comments
 (0)