Skip to content

Commit 01f909e

Browse files
authored
fix autocast (#235)
1 parent 2519495 commit 01f909e

File tree

8 files changed

+7508
-9
lines changed

8 files changed

+7508
-9
lines changed

graph_net/torch/extractor.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
torch._dynamo.config.capture_scalar_outputs = True
99
torch._dynamo.config.capture_dynamic_output_shape_ops = True
10+
torch._dynamo.config.capture_sparse_compute = True
11+
torch._dynamo.config.raise_on_ctx_manager_usage = False
12+
torch._dynamo.config.allow_rnn = True
1013

1114

1215
def extract(name, dynamic=True, mut_graph_codes=None, placeholder_auto_rename=False):
@@ -84,6 +87,11 @@ def wrapper(model: torch.nn.Module):
8487
class GraphExtractor:
8588
def __init__(self):
8689
self.subgraph_counter = 0
90+
self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
91+
if not self.workspace_path:
92+
raise EnvironmentError(
93+
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
94+
)
8795

8896
def move_files(self, source_dir, target_dir):
8997
os.makedirs(target_dir, exist_ok=True)
@@ -94,13 +102,8 @@ def move_files(self, source_dir, target_dir):
94102
shutil.move(source_path, target_path)
95103

96104
def __call__(self, gm: torch.fx.GraphModule, sample_inputs):
97-
# 1. Get workspace path
98-
workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE")
99-
if not workspace_path:
100-
raise EnvironmentError(
101-
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
102-
)
103-
model_path = os.path.join(workspace_path, name)
105+
# 1. Get model path
106+
model_path = os.path.join(self.workspace_path, name)
104107
os.makedirs(model_path, exist_ok=True)
105108

106109
if self.subgraph_counter == 0:
@@ -140,6 +143,15 @@ def try_rename_placeholder(node):
140143
input = torch.tensor(4)
141144
params[node.target] = input
142145
input_idx += 1
146+
147+
if node.op == "call_function" and hasattr(node.target, "__name__"):
148+
if node.target.__name__ in [
149+
"_enter_autocast",
150+
"_exit_autocast",
151+
]:
152+
node.replace_all_uses_with(node.args[0])
153+
gm.graph.erase_node(node)
154+
143155
assert input_idx == len(sample_inputs)
144156
if mut_graph_codes is not None:
145157
assert isinstance(mut_graph_codes, list)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fdf650e0bdc3326eccec3edf5adcc6683a1d516b7d8192a39e3fd1d27bfc423f
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"framework": "torch",
3+
"num_devices_required": 1,
4+
"num_nodes_required": 1,
5+
"dynamic": false,
6+
"model_name": "Qwen/Qwen1.5-0.5B"
7+
}

samples/transformers-auto-model/Qwen1.5-0.5B/input_meta.py

Whitespace-only changes.

samples/transformers-auto-model/Qwen1.5-0.5B/input_tensor_constraints.py

Whitespace-only changes.

samples/transformers-auto-model/Qwen1.5-0.5B/model.py

Lines changed: 4498 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)