Skip to content

Commit 707c113

Browse files
committed
Fix naming of results in ODS generator
This commit fixes the naming of results in the torch ODS generator when dealing with multiple results. In particular, this commit appends an index to each result name to guarantee that they are all unique.
1 parent 829cf8a commit 707c113

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,11 +1391,11 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
13911391
Torch_FloatType:$eps
13921392
);
13931393
let results = (outs
1394-
AnyTorchTensorType:$layer_norm,
1395-
AnyTorchTensorType:$mean,
1396-
AnyTorchTensorType:$variance
1394+
AnyTorchTensorType:$result0,
1395+
AnyTorchTensorType:$result1,
1396+
AnyTorchTensorType:$result2
13971397
);
1398-
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($layer_norm) `,` type($mean) `,` type($variance)";
1398+
let assemblyFormat = "$input `,` $normalized_shape `,` $weight `,` $bias `,` $eps attr-dict `:` type($input) `,` type($normalized_shape) `,` type($weight) `,` type($bias) `,` type($eps) `->` type($result0) `,` type($result1) `,` type($result2)";
13991399
}
14001400

14011401
def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,11 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
297297
emitter = TextEmitter(f)
298298
p = lambda *args: emitter.print(*args)
299299
op_name, td_def_name = operator.get_mlir_names()
300+
301+
# Generate unique result names for ops with nameless results
302+
multiple_results = len(operator.returns) > 1
303+
generic_result_name = lambda i: "result" + (str(i) if multiple_results else "")
304+
300305
p(f"def {td_def_name} : Torch_Op<{emitter.quote(op_name)}, [")
301306
with emitter.indent():
302307
with emitter.indent():
@@ -321,8 +326,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
321326
p("Variadic<AnyTorchType>:$results")
322327
else:
323328
p(",\n".join([
324-
f"""{get_ods_type(ret["type"])}:${ret["name"] or "result"}"""
325-
for ret in operator.returns
329+
f"""{get_ods_type(ret["type"])}:${ret["name"] or generic_result_name(e)}"""
330+
for e, ret in enumerate(operator.returns)
326331
]))
327332
p(");")
328333

@@ -338,8 +343,8 @@ def raw_emit_op(operator: JitOperator, f: TextIO, *, traits: List[str],
338343
assembly_result_types = "type($results)"
339344
else:
340345
assembly_result_types = " `,` ".join(
341-
f"""type(${ret["name"] or "result"})"""
342-
for ret in operator.returns)
346+
f"""type(${ret["name"] or generic_result_name(e)})"""
347+
for e, ret in enumerate(operator.returns))
343348
if assembly_operand_types and assembly_result_types:
344349
maybe_arrow = " `->` "
345350
else:

0 commit comments

Comments
 (0)