Skip to content

Commit d67afa9

Browse files
authored
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543)
1 parent b26797c commit d67afa9

File tree

4 files changed

+118
-50
lines changed

4 files changed

+118
-50
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,55 +2102,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
21022102
}];
21032103
}
21042104

2105-
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
2106-
AllowsTypeRefinement,
2107-
HasValueSemantics,
2108-
ReadOnly
2109-
]> {
2110-
let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
2111-
let arguments = (ins
2112-
AnyTorchTensorType:$self,
2113-
AnyTorchTensorType:$mask,
2114-
AnyTorchTensorType:$value
2115-
);
2116-
let results = (outs
2117-
AnyTorchTensorType:$result
2118-
);
2119-
let hasCustomAssemblyFormat = 1;
2120-
let extraClassDefinition = [{
2121-
ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
2122-
return parseDefaultTorchOp(parser, result, 3, 1);
2123-
}
2124-
void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
2125-
printDefaultTorchOp(printer, *this, 3, 1);
2126-
}
2127-
}];
2128-
}
2129-
2130-
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
2131-
IsTrailingUnderscoreInplaceVariant,
2132-
AllowsTypeRefinement
2133-
]> {
2134-
let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
2135-
let arguments = (ins
2136-
Torch_NonValueTensorType:$self,
2137-
Torch_NonValueTensorType:$mask,
2138-
Torch_NonValueTensorType:$value
2139-
);
2140-
let results = (outs
2141-
Torch_NonValueTensorType:$result
2142-
);
2143-
let hasCustomAssemblyFormat = 1;
2144-
let extraClassDefinition = [{
2145-
ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
2146-
return parseDefaultTorchOp(parser, result, 3, 1);
2147-
}
2148-
void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
2149-
printDefaultTorchOp(printer, *this, 3, 1);
2150-
}
2151-
}];
2152-
}
2153-
21542105
def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
21552106
AllowsTypeRefinement,
21562107
HasValueSemantics,
@@ -3658,6 +3609,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
36583609
}];
36593610
}
36603611

3612+
// Value-semantic variant of `aten::masked_fill.Tensor`. `hasCanonicalizer`
// declares the fold of a 0-d fill tensor into `aten.masked_fill.Scalar`
// (implemented in TorchOps.cpp).
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$mask,
    AnyTorchTensorType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasCanonicalizer = 1;
}
3637+
3638+
// In-place (trailing-underscore) variant of `aten::masked_fill.Tensor`;
// operates on non-value tensors and mutates `self`.
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$mask,
    Torch_NonValueTensorType:$value
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
3661+
36613662
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
36623663
AllowsTypeRefinement,
36633664
HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc,
162162
return nullptr;
163163
}
164164

165+
/// If `input` is a statically-known floating-point scalar — either already a
/// `!torch.float` value, or a rank-0 float tensor produced by a recognized op —
/// materialize and return it as a `!torch.float` Value; otherwise return
/// nullptr.
static Value getScalarFloatValue(Value input, Location loc,
                                 PatternRewriter &rewriter) {
  Type type = input.getType();
  // Already a scalar float: nothing to extract.
  if (type.isa<Torch::FloatType>())
    return input;

  auto tensorType = type.dyn_cast<BaseTensorType>();
  if (!tensorType)
    return nullptr;

  // Only tensors with a known f16/f32/f64 dtype are handled.
  Type dtype = tensorType.getOptionalDtype();
  if (!dtype || (!dtype.isF16() && !dtype.isF32() && !dtype.isF64()))
    return nullptr;

  // Require a 0-d tensor so it carries exactly one element.
  std::optional<unsigned> rank = getTensorRank(input);
  if (!rank || *rank != 0)
    return nullptr;

  // A 0-d literal is necessarily a splat; pull out its double value and emit
  // it as a constant float op.
  if (auto literal = input.getDefiningOp<ValueTensorLiteralOp>()) {
    double splat = literal.getValue()
                       .cast<DenseFPElementsAttr>()
                       .getSplatValue<FloatAttr>()
                       .getValueAsDouble();
    return rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(splat));
  }
  // These producers wrap an existing scalar operand; just forward it.
  if (auto numToTensor = input.getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensor.getA();
  if (auto tensorFloat = input.getDefiningOp<AtenTensorFloatOp>())
    return tensorFloat.getT();
  return nullptr;
}
200+
165201
//===----------------------------------------------------------------------===//
166202
// MethodOp
167203
//===----------------------------------------------------------------------===//
@@ -1589,6 +1625,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
15891625
return nullptr;
15901626
}
15911627

1628+
//===----------------------------------------------------------------------===//
1629+
// AtenMaskedFillTensorOp
1630+
//===----------------------------------------------------------------------===//
1631+
1632+
// Fold 0d fill tensor to scalar
1633+
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
1634+
RewritePatternSet &patterns, MLIRContext *context) {
1635+
patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
1636+
auto scalarIntVal =
1637+
getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
1638+
auto scalarFloatVal =
1639+
getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
1640+
if (!scalarIntVal && !scalarFloatVal)
1641+
return failure();
1642+
Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
1643+
rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
1644+
op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
1645+
return failure();
1646+
});
1647+
}
1648+
15921649
//===----------------------------------------------------------------------===//
15931650
// AtenSortIntOp
15941651
//===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ def emit_with_mutating_variants(key, **kwargs):
300300
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
301301
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
302302
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
303-
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
304303
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
305304
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
306305
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
@@ -337,6 +336,7 @@ def emit_with_mutating_variants(key, **kwargs):
337336
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
338337
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
339338
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
339+
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
340340

341341
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
342342
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t
21362136
%0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int
21372137
return %0 : !torch.int
21382138
}
2139+
2140+
// Verify the canonicalizer folds a 0-d fill tensor into masked_fill.Scalar:
// the vtensor.literal becomes a constant.float feeding the Scalar variant.
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize
// CHECK-NEXT: torch.constant.float -1.000000e+09
// CHECK-NEXT: torch.aten.masked_fill.Scalar
// CHECK-NEXT: return
func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor<f32>) : !torch.vtensor<[],f32>
  %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
  return %1 : !torch.vtensor<[?,?],f32>
}

0 commit comments

Comments
 (0)