Skip to content

Commit 63b5ee5

Browse files
committed
update snippet
1 parent 884acd2 commit 63b5ee5

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ topo_loss = TopoLoss(
2121
LaplacianPyramid(layer_name = 'fc',factor_h=3.0, factor_w=3.0),
2222
],
2323
)
24+
loss = topo_loss.compute(model=model)
25+
## >>> tensor(0.8407, grad_fn=<DivBackward0>)
26+
loss.backward()
2427

25-
print(topo_loss.compute(model=model, reduce_mean = True)) ## returns a single number as tensor for backward()
26-
print(topo_loss.compute(model=model, reduce_mean = False)) ## returns a dict with layer names as keys
28+
loss_dict = topo_loss.compute(model=model, reduce_mean = False) ## {"fc": }
29+
## >>> {'fc': tensor(0.8407, grad_fn=<MulBackward0>)}
2730
```
2831

2932
## Running tests

0 commit comments

Comments
 (0)