Skip to content

Commit 8b6e2b5

Browse files
committed
jsonify modules better
1 parent b7ed4ae commit 8b6e2b5

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

gbmi/utils/hashing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import numpy
2929
import torch
30+
import transformer_lens
3031
from transformer_lens import HookedTransformer
3132

3233
# Implemented for https://github.com/lemon24/reader/issues/179
@@ -149,6 +150,16 @@ def _json_default(
149150
exclude_filter=exclude_filter,
150151
dictify_by_default=dictify_by_default,
151152
)
153+
elif isinstance(thing, torch.nn.Module):
154+
return _json_dumps(
155+
{
156+
"type": type(thing),
157+
"repr": repr(thing),
158+
"module.parameters": list(thing.parameters()),
159+
},
160+
exclude_filter=exclude_filter,
161+
dictify_by_default=dictify_by_default,
162+
)
152163
elif isinstance(thing, type):
153164
return f"{thing.__module__}.{thing.__name__}"
154165
elif (

0 commit comments

Comments
 (0)