Skip to content

Commit 62e1851

Browse files
jamt9000rom1504
authored andcommitted
Fix torch._C.Node attribute access (openai#372)
Attribute access with subscripting would previously work due to patching in pytorch/pytorch#82511 but this has been removed. This commit uses the fix proposed in pytorch/pytorch#82628 to define a helper method to call the appropriate access method.
1 parent a44317f commit 62e1851

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

clip/clip.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
146146
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
147147
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
148148

149+
def _node_get(node: torch._C.Node, key: str):
150+
"""Gets attributes of a node which is polymorphic over return type.
151+
152+
From https://github.com/pytorch/pytorch/pull/82628
153+
"""
154+
sel = node.kindOf(key)
155+
return getattr(node, sel)(key)
156+
149157
def patch_device(module):
150158
try:
151159
graphs = [module.graph] if hasattr(module, "graph") else []
@@ -157,7 +165,7 @@ def patch_device(module):
157165

158166
for graph in graphs:
159167
for node in graph.findAllNodes("prim::Constant"):
160-
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
168+
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
161169
node.copyAttributes(device_node)
162170

163171
model.apply(patch_device)
@@ -183,7 +191,7 @@ def patch_float(module):
183191
for node in graph.findAllNodes("aten::to"):
184192
inputs = list(node.inputs())
185193
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
186-
if inputs[i].node()["value"] == 5:
194+
if _node_get(inputs[i].node(), "value") == 5:
187195
inputs[i].node().copyAttributes(float_node)
188196

189197
model.apply(patch_float)

0 commit comments

Comments
 (0)