Skip to content

Commit 37b70fb

Browse files
committed
update advanced tutorials
1 parent 54823db commit 37b70fb

File tree

11 files changed

+1046
-591
lines changed

11 files changed

+1046
-591
lines changed

brainpy/algorithms/offline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,16 @@ def cond_fun(a):
149149
i < self.max_iter).value
150150

151151
def body_fun(a):
152-
i, par_old, par_new = a
152+
i, _, par_new = a
153153
# Gradient of regularization loss w.r.t w
154-
y_pred = inputs.dot(par_old)
154+
y_pred = inputs.dot(par_new)
155155
grad_w = bm.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new)
156156
# Update the weights
157157
par_new2 = par_new - self.learning_rate * grad_w
158158
return i + 1, par_new, par_new2
159159

160160
# Tune parameters for n iterations
161-
r = while_loop(cond_fun, body_fun, (0, w, w + 1e-8))
161+
r = while_loop(cond_fun, body_fun, (0, w - 1e-8, w))
162162
return r[-1]
163163

164164
def predict(self, W, X):

brainpy/math/operators/op_register.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ def __call__(self, *args, **kwargs):
9494

9595

9696
def register_op(
97-
op_name: str,
97+
name: str,
98+
eval_shape: Union[Callable, ShapedArray, Sequence[ShapedArray]],
9899
cpu_func: Callable,
99-
out_shapes: Union[Callable, ShapedArray, Sequence[ShapedArray]],
100100
gpu_func: Callable = None,
101101
apply_cpu_func_to_gpu: bool = False
102102
):
@@ -105,13 +105,13 @@ def register_op(
105105
106106
Parameters
107107
----------
108-
op_name: str
108+
name: str
109109
Name of the operators.
110110
cpu_func: Callable
111111
A callable numba-jitted function or pure function (can be lambda function) running on CPU.
112112
gpu_func: Callable, default = None
113113
A callable cuda-jitted kernel running on GPU.
114-
out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None
114+
eval_shape: Callable, ShapedArray, Sequence[ShapedArray], default = None
115115
Outputs shapes of target function. `eval_shape` can be a `ShapedArray` or
116116
a sequence of `ShapedArray`. If it is a function, it takes as input the argument
117117
shapes and dtypes and should return correct output shapes of `ShapedArray`.
@@ -123,10 +123,10 @@ def register_op(
123123
A jitable JAX function.
124124
"""
125125
_check_brainpylib(register_op.__name__)
126-
f = brainpylib.register_op(op_name,
126+
f = brainpylib.register_op(name,
127127
cpu_func=cpu_func,
128128
gpu_func=gpu_func,
129-
out_shapes=out_shapes,
129+
out_shapes=eval_shape,
130130
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
131131

132132
def fixed_op(*inputs):

brainpy/math/operators/tests/test_op_register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def event_sum_op(outs, ins):
2323
outs[index] += v
2424

2525

26-
event_sum = bm.register_op(op_name='event_sum', cpu_func=event_sum_op, out_shapes=abs_eval)
26+
event_sum = bm.register_op(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
2727
event_sum = bm.jit(event_sum)
2828

2929

docs/index.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,10 @@ The code of BrainPy is open-sourced at GitHub:
7777
:caption: Advanced Tutorials
7878

7979
tutorial_advanced/variables
80-
tutorial_advanced/base
80+
tutorial_advanced/base_and_collector
8181
tutorial_advanced/compilation
8282
tutorial_advanced/differentiation
83-
tutorial_advanced/control_flows
84-
tutorial_advanced/low-level_operator_customization
83+
tutorial_advanced/operator_customization
8584
tutorial_advanced/interoperation
8685

8786

docs/tutorial_advanced/base.ipynb renamed to docs/tutorial_advanced/base_and_collector.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
}
1010
},
1111
"source": [
12-
"# Base Class"
12+
"# Fundamental Base and Collector Objects"
1313
]
1414
},
1515
{

docs/tutorial_advanced/differentiation.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
}
1010
},
1111
"source": [
12-
"# Autograd for Class Variables"
12+
"# Automatic Differentiation for Class Variables"
1313
]
1414
},
1515
{

0 commit comments

Comments (0)