Skip to content

Commit bdfd050

Browse files
committed
Fix code formatting: apply black formatter
1 parent fd21fc8 commit bdfd050

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,20 @@ def _is_float32_tensor(self, node: fx.Node) -> bool:
134134
# Check type annotation if available
135135
if node.type is not None:
136136
type_str = str(node.type).lower()
137-
137+
138138
# Explicitly check for integer types - these should NOT be converted
139139
integer_types = ["long", "int", "short", "byte", "bool"]
140140
if any(int_type in type_str for int_type in integer_types):
141141
return False
142-
142+
143143
# Only return True if explicitly a floating point tensor
144144
# Check for explicit float types: FloatTensor, float32, float16, etc.
145145
float_indicators = ["float", "double", "half", "bfloat"]
146-
if any(float_indicator in type_str for float_indicator in float_indicators):
146+
if any(
147+
float_indicator in type_str for float_indicator in float_indicators
148+
):
147149
return True
148-
150+
149151
# For generic "Tensor" without explicit dtype, be conservative
150152
# Don't assume it's float32 - it might be integer
151153
return False

graph_net/torch/dtype_generalizer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __call__(self, model_path: str) -> None:
8282
# Apply model_path_prefix if provided
8383
if self.model_path_prefix:
8484
model_path = str(Path(self.model_path_prefix) / model_path)
85-
85+
8686
# Parse the computation graph
8787
traced_model = parse_immutable_model_path_into_sole_graph_module(model_path)
8888

@@ -186,7 +186,9 @@ def _save_dtype_pass_names(
186186
model_path: Path to model directory
187187
"""
188188
graph_net_json_path = Path(model_path) / "graph_net.json"
189-
update_json(graph_net_json_path, {kDataTypeGeneralizationPasses: dtype_pass_names})
189+
update_json(
190+
graph_net_json_path, {kDataTypeGeneralizationPasses: dtype_pass_names}
191+
)
190192

191193

192194
class ApplyDataTypeGeneralizationPasses:
@@ -211,17 +213,16 @@ def __init__(self, config: Dict[str, Any]):
211213
self.output_dir = config.get("output_dir")
212214
if not self.output_dir:
213215
raise ValueError("output_dir is required in config")
214-
216+
215217
self.model_path_prefix = config.get("model_path_prefix", "")
216-
218+
217219
# model_runnable_predicator is required to ensure generated code is runnable
218220
if "model_runnable_predicator_filepath" not in config:
219221
raise ValueError(
220222
"model_runnable_predicator_filepath is required in config. "
221223
"Generated code must be validated."
222224
)
223225
self.model_runnable_predicator = self._make_model_runnable_predicator(config)
224-
225226
def _make_model_runnable_predicator(self, config: Dict[str, Any]):
226227
"""Create model runnable predicator from config."""
227228
module = load_module(config["model_runnable_predicator_filepath"])
@@ -245,7 +246,7 @@ def __call__(self, model_path: str) -> List[str]:
245246
# Apply model_path_prefix if provided
246247
if self.model_path_prefix:
247248
model_path = str(Path(self.model_path_prefix) / model_path)
248-
249+
249250
# Read pass names from graph_net.json
250251
dtype_pass_names = self._read_dtype_pass_names(model_path)
251252

0 commit comments

Comments
 (0)