Skip to content

Commit 6ca39bc

Browse files
Merge branch 'main' into tosa-clamp-f16-bf16
2 parents 83f41f9 + b1053f8 commit 6ca39bc

File tree

29 files changed

+1679
-575
lines changed

29 files changed

+1679
-575
lines changed

externals/llvm-project

Submodule llvm-project updated 19594 files

externals/stablehlo

Submodule stablehlo updated 80 files

include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
6262

6363
let builders = [
6464
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
65-
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
65+
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive), [{
66+
build($_builder, $_state, TypeRange(outputs), inputs, outputs, dimension, inclusive);
67+
}]>
6668
];
6769

6870
let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -267,7 +269,9 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
267269
);
268270

269271
let builders = [
270-
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
272+
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs), [{
273+
build($_builder, $_state, TypeRange(outputs), inputs, outputs);
274+
}]>
271275
];
272276

273277
let results = (outs Variadic<AnyRankedTensor>:$result);

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
86688668
}];
86698669
}
86708670

8671+
def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
8672+
AllowsTypeRefinement,
8673+
HasValueSemantics,
8674+
ReadOnly
8675+
]> {
8676+
let summary = "Generated op for `aten::channel_shuffle : (Tensor, int) -> (Tensor)`";
8677+
let arguments = (ins
8678+
AnyTorchTensorType:$self,
8679+
Torch_IntType:$groups
8680+
);
8681+
let results = (outs
8682+
AnyTorchOptionalTensorType:$result
8683+
);
8684+
let hasCustomAssemblyFormat = 1;
8685+
let extraClassDefinition = [{
8686+
ParseResult AtenChannelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
8687+
return parseDefaultTorchOp(parser, result, 2, 1);
8688+
}
8689+
void AtenChannelShuffleOp::print(OpAsmPrinter &printer) {
8690+
printDefaultTorchOp(printer, *this, 2, 1);
8691+
}
8692+
}];
8693+
}
8694+
86718695
def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
86728696
AllowsTypeRefinement,
86738697
ReadOnly
@@ -14350,7 +14374,7 @@ def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata",
1435014374
printDefaultTorchOp(printer, *this, 6, 0);
1435114375
}
1435214376
}];
14353-
let hasFolder = 1;
14377+
let hasCanonicalizer = 1;
1435414378
}
1435514379

1435614380
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,74 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16061606
/* cudnn enabled */ boolFalse);
16071607
return success();
16081608
});
1609+
patterns.onOp(
1610+
"MeanVarianceNormalization", 13,
1611+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
1612+
Torch::ValueTensorType resultType;
1613+
Value input;
1614+
SmallVector<int64_t> axes;
1615+
1616+
if (binder.tensorOperand(input) ||
1617+
binder.s64IntegerArrayAttr(axes, "axes",
1618+
llvm::SmallVector<int64_t>({0, 2, 3})) ||
1619+
binder.tensorResultType(resultType)) {
1620+
return failure();
1621+
}
1622+
if (!resultType.hasSizes() || !resultType.hasDtype()) {
1623+
return failure();
1624+
}
1625+
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
1626+
if (!inputTy || !inputTy.hasSizes()) {
1627+
return failure();
1628+
}
1629+
int64_t inputRank = inputTy.getSizes().size();
1630+
1631+
Location loc = binder.getLoc();
1632+
Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
1633+
Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
1634+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
1635+
1636+
ArrayRef<int64_t> output_shape = resultType.getSizes();
1637+
SmallVector<int64_t> reduced_shape(output_shape);
1638+
1639+
for (int64_t i : axes) {
1640+
int64_t dim = Torch::toPositiveDim(i, inputRank);
1641+
if (!Torch::isValidDim(dim, inputRank)) {
1642+
return failure();
1643+
}
1644+
reduced_shape[dim] = 1;
1645+
}
1646+
Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
1647+
resultType.getContext(), reduced_shape, resultType.getDtype());
1648+
SmallVector<Value> cstAxes;
1649+
for (int64_t i : axes) {
1650+
cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
1651+
loc, rewriter.getI64IntegerAttr(i)));
1652+
}
1653+
Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
1654+
loc,
1655+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
1656+
cstAxes);
1657+
Value mean = rewriter.create<Torch::AtenMeanDimOp>(
1658+
loc, reducedOutTy, input, axes_list, keepDim, none);
1659+
Value variance = rewriter.create<Torch::AtenVarDimOp>(
1660+
loc, reducedOutTy, input, axes_list, unBiased, keepDim);
1661+
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
1662+
loc, rewriter.getI64IntegerAttr(1));
1663+
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
1664+
loc, rewriter.getF64FloatAttr(1e-9));
1665+
variance = rewriter.create<Torch::AtenAddScalarOp>(
1666+
loc, reducedOutTy, variance, cstEps, cstOne);
1667+
Value sqrtVar =
1668+
rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
1669+
Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
1670+
loc, resultType, input, mean, cstOne);
1671+
Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
1672+
loc, resultType, inputMinusMean, sqrtVar);
1673+
1674+
rewriter.replaceOp(binder.op, meanVarNorm);
1675+
return success();
1676+
});
16091677
patterns.onOp(
16101678
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
16111679
Torch::ValueTensorType resultType;

0 commit comments

Comments
 (0)