Skip to content

Commit a479334

Browse files
committed
Fix: process torch.device and torch.dtype in inputs
1 parent 6529dac commit a479334

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

graph_net/torch/decompose_util.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,14 @@ def get_args_node(arg):
248248
yield arg.start
249249
yield arg.stop
250250
yield arg.step
251+
elif isinstance(arg, torch.device):
252+
pass
253+
elif isinstance(arg, torch.dtype):
254+
pass
251255
else:
252-
assert isinstance(arg, (int, bool, float, str, type(None))), f"{type(arg)=}"
256+
assert isinstance(
257+
arg, (int, bool, float, str, type(None), torch.device, torch.dtype)
258+
), f"{type(arg)=}"
253259

254260
def get_args_node_and_self_node(node):
255261
for arg in node.args:

0 commit comments

Comments
 (0)