Skip to content

Commit fbf07e4

Browse files
committed
Update mnist_ResNet.py
1 parent 181ae4f commit fbf07e4

File tree

1 file changed

+1
-20
lines changed

1 file changed

+1
-20
lines changed

examples/training_ann_models/mnist_ResNet.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,6 @@
1414

1515
bm.set_environment(mode=bm.training_mode, dt=1.)
1616

17-
class Partial(bm.FunAsObject):
18-
def __init__(
19-
self,
20-
fun,
21-
*args,
22-
child_objs = None,
23-
dyn_vars = None,
24-
**keywords
25-
):
26-
super().__init__(f=fun, child_objs=child_objs, dyn_vars=dyn_vars)
27-
28-
self.fun = fun
29-
self.args = args
30-
self.keywords = keywords
31-
32-
def __call__(self, /, *args, **keywords):
33-
keywords = {**self.keywords, **keywords}
34-
return self.fun(*self.args, *args, **keywords)
35-
3617

3718
class BasicBlock(bp.DynamicalSystem):
3819
expansion = 1
@@ -226,7 +207,7 @@ def main():
226207
y_test = bm.asarray(test_set.targets, dtype=bm.int_)
227208

228209
with bm.training_environment():
229-
net = Partial(ResNet18(num_classes=10), is_feat=False, preact=False)
210+
net = ResNet18(num_classes=10)
230211

231212
# loss function
232213
@bm.to_object(child_objs=net)

0 commit comments

Comments
 (0)