Commit 57d5e00

update custom op conversions (#92)
- Fix a minor issue with the broadcast op (fold the broadcast if the axes attr is empty)
- Add the following ops to `tcp.custom_op` (an example of the resulting IR is sketched below):
  - `aten.sort`
  - `aten.cumsum`
  - `aten.min.dim`
  - `aten.view` (dynamic shape only)
  - `aten.topk`

To test: `bazel test //...` (in docker)
1 parent 7b53fe4 commit 57d5e00
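
For orientation, here is a condensed sketch of the IR these conversions produce, adapted from the `torch.aten.sort` lit test added below (shapes and attribute values are taken from that test): the torch-level op is rewritten to a `tcp.custom_op`, its non-tensor arguments become attributes, and its tensor operand names are recorded in `torch_operand_names`.

// Torch-level input (values mirror the aten.sort test in tcp_custom_ops.mlir):
%values, %indices = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>

// After the torch-to-tcp custom-op conversion:
%t0 = torch_c.to_builtin_tensor %input : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
%sorted, %idx = tcp.custom_op("torch.aten.sort") %t0 {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} : tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>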

4 files changed: +242 -1 lines changed


lib/Conversion/TorchToTcp/Misc.cpp

Lines changed: 5 additions & 1 deletion
@@ -127,11 +127,15 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
       }
     }
 
+    // fold the broadcast if no axes are found
+    if (axes.size() == 0) {
+      rewriter.replaceOp(op, input);
+      return success();
+    }
     RankedTensorType resultType =
         OpConversionPattern<AtenOpT>::getTypeConverter()
             ->convertType(op->getResult(0).getType())
             .template cast<RankedTensorType>();
-
     auto axesAttr = rewriter.getI64ArrayAttr(axes);
     rewriter.replaceOpWithNewOp<tcp::BroadcastOp>(op, resultType, input,
                                                   resultShape, axesAttr);
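
The fold itself is not covered by a new lit test in this commit; as a hypothetical illustration (assuming `aten.broadcast_to` is among the broadcast-like ops handled by `ConvertAtenBroadcastLikeOps`, with made-up shapes), the intended effect is:

// A broadcast whose target shape adds no new broadcast axes.
%sizes = torch.prim.ListConstruct %int4, %int16 : (!torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %input, %sizes : !torch.vtensor<[4,16],f32>, !torch.list<int> -> !torch.vtensor<[4,16],f32>

// Before this change the conversion emitted a tcp.broadcast with an empty axes
// attribute; with the fold, the op is simply replaced by the converted input
// tensor and no tcp.broadcast is created.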

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

Lines changed: 148 additions & 0 deletions
@@ -14,9 +14,12 @@
 
 #include "PopulatePatterns.h"
 #include "Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 
 #include "llvm/ADT/StringSet.h"
 
@@ -211,6 +214,145 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
   }
 };
 
+class ConvertAtenTopkOp : public OpConversionPattern<AtenTopkOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenTopkOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    helper.addOperand("self", adaptor.getSelf());
+
+    helper.addIntAttr("k", op.getK());
+    helper.addIntAttr("dim", op.getDim());
+    helper.addBoolAttr("largest", op.getLargest());
+    helper.addBoolAttr("sorted", op.getSorted());
+
+    return helper.replace();
+  }
+};
+
+class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenSortOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    helper.addOperand("self", adaptor.getSelf());
+
+    helper.addIntAttr("dim", op.getDim());
+    helper.addBoolAttr("descending", op.getDescending());
+
+    return helper.replace();
+  }
+};
+
+class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    helper.addOperand("self", adaptor.getSelf());
+
+    helper.addIntAttr("dim", op.getDim());
+    if (!isa<Torch::ConstantNoneOp>(op.getDtype().getDefiningOp()))
+      return rewriter.notifyMatchFailure(op, "Unsupported dtype argument");
+
+    return helper.replace();
+  }
+};
+
+class ConvertAtenMinDimOp : public OpConversionPattern<AtenMinDimOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenMinDimOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    helper.addOperand("self", adaptor.getSelf());
+
+    helper.addIntAttr("dim", op.getDim());
+    helper.addBoolAttr("keepdim", op.getKeepdim());
+
+    return helper.replace();
+  }
+};
+
+class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
+                                                            getTypeConverter()};
+    Value self = adaptor.getSelf();
+    auto srcType = self.getType().cast<RankedTensorType>();
+    auto resultType =
+        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+
+    SmallVector<int64_t> size;
+    // static shape will be handled through TOSA dialect
+    if (matchPattern(op.getSize(), m_TorchListOfConstantInts(size)) &&
+        srcType.hasStaticShape() && resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "only dynamic shape is supported");
+
+    helper.addOperand("self", self);
+    Operation *primListOp = op.getSize().getDefiningOp();
+    auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(primListOp);
+    if (!listConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "Size must come from PrimListConstructOp");
+    }
+    int idx = 0;
+    for (Value value : listConstruct.getElements()) {
+      int64_t dimSize;
+      if (matchPattern(value, m_TorchConstantInt(&dimSize))) {
+        size.push_back(dimSize);
+      } else {
+        size.push_back(ShapedType::kDynamic);
+        // dynamic shape should follow pattern:
+        // %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
+        // %1 = arith.index_cast %dim_32 : index to i64
+        // %2 = torch_c.from_i64 %1
+        // %3 = torch.prim.ListConstruct %2 ...
+        if (!isa<TorchConversion::FromI64Op>(value.getDefiningOp()))
+          return rewriter.notifyMatchFailure(
+              op, "dynamic dim size should come from FromI64Op");
+        auto conversionOp =
+            dyn_cast<TorchConversion::FromI64Op>(value.getDefiningOp());
+        if (!isa<arith::IndexCastOp>(conversionOp.getOperand().getDefiningOp()))
+          return rewriter.notifyMatchFailure(
+              op, "dynamic dim size should come from IndexCastOp");
+        auto indexCastOp = dyn_cast<arith::IndexCastOp>(
+            conversionOp.getOperand().getDefiningOp());
+        if (!isa<tensor::DimOp>(indexCastOp.getIn().getDefiningOp()))
+          return rewriter.notifyMatchFailure(
+              op, "dynamic dim size should come from DimOp");
+        auto dimOp =
+            dyn_cast<tensor::DimOp>(indexCastOp.getIn().getDefiningOp());
+        helper.addOperand("idx_" + std::to_string(idx), dimOp);
+      }
+      idx++;
+    }
+    helper.addDenseIntArrayAttr("size", size);
+
+    return helper.replace();
+  }
+};
+
 } // namespace
 
 void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -227,6 +369,12 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
       AtenFakeQuantizePerTensorAffineTensorQparamsOp);
   INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerChannelAffineOp);
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenTopkOp);
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp);
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
+  INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
+  // AtenViewOp can still live after torch-to-tcp conversion
+  patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
 #undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
 
   // Torch -> TOSA doesn't handle transposed convolutions; map them to

lib/InitAll.cpp

Lines changed: 2 additions & 0 deletions
@@ -18,10 +18,12 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/Dialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 
 void mlir::tcp::registerAllDialects(mlir::DialectRegistry &registry) {
   registry.insert<tcp::TcpDialect>();
   registry.insert<torch::Torch::TorchDialect>();
+  registry.insert<torch::TorchConversion::TorchConversionDialect>();
   mlir::func::registerInlinerExtension(registry);
   mlir::tcp::registerTilingInterfaceExternalModels(registry);
 }

test/Conversion/TorchToTcp/tcp_custom_ops.mlir

Lines changed: 87 additions & 0 deletions
@@ -254,3 +254,90 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
   %output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
   return %output : !torch.vtensor<[1,3,32,32],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.topk(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
+// CHECK:         %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.topk") %[[T0]] {dim = -1 : i64, k = 80 : i64, largest = true, sorted = true, torch_operand_names = ["self"]} :
+// CHECK-SAME:    tensor<?x2304xf32> -> tensor<?x80xf32>, tensor<?x80xi64>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x80xf32> -> !torch.vtensor<[?,80],f32>
+// CHECK:         return %[[RES]] : !torch.vtensor<[?,80],f32>
+func.func @torch.aten.topk(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,80],f32> {
+  %int-1 = torch.constant.int -1
+  %int80 = torch.constant.int 80
+  %true = torch.constant.bool true
+  %output0, %output1 = torch.aten.topk %input, %int80, %int-1, %true, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[?,80],f32>, !torch.vtensor<[?,80],si64>
+  return %output0 : !torch.vtensor<[?,80],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.sort(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
+// CHECK:         %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.sort") %[[T0]] {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} :
+// CHECK-SAME:    tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x2304xf32> -> !torch.vtensor<[?,2304],f32>
+// CHECK:         return %[[RES]] : !torch.vtensor<[?,2304],f32>
+func.func @torch.aten.sort(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+  %int-1 = torch.constant.int -1
+  %true = torch.constant.bool true
+  %output0, %output1 = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>
+  return %output0 : !torch.vtensor<[?,2304],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.cumsum(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?],si32> -> tensor<?xi32>
+// CHECK:         %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.cumsum") %[[T0]] {dim = 0 : i64, torch_operand_names = ["self"]} : tensor<?xi32> -> tensor<?xi64>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?xi64> -> !torch.vtensor<[?],si64>
+// CHECK:         return %[[RES]] : !torch.vtensor<[?],si64>
+func.func @torch.aten.cumsum(%input: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+  %int0 = torch.constant.int 0
+  %none = torch.constant.none
+  %1 = torch.aten.cumsum %input, %int0, %none : !torch.vtensor<[?],si32>, !torch.int, !torch.none -> !torch.vtensor<[?],si64>
+  return %1 : !torch.vtensor<[?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.min.dim(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,80],f32> -> tensor<?x80xf32>
+// CHECK:         %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.min.dim") %[[T0]] {dim = 1 : i64, keepdim = false, torch_operand_names = ["self"]} :
+// CHECK-SAME:    tensor<?x80xf32> -> tensor<?xf32>, tensor<?xi64>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
+// CHECK:         return %[[RES]] : !torch.vtensor<[?],f32>
+func.func @torch.aten.min.dim(%input: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %output0, %output1 = torch.aten.min.dim %input, %int1, %false : !torch.vtensor<[?,80],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
+  return %output0 : !torch.vtensor<[?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.view_dynamic_shape(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[?,384,16],f32>, %[[ARG1:.*]]: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,384,16],f32> -> tensor<?x384x16xf32>
+// CHECK:         %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x2736x16xf32>
+// CHECK:         %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.view") %[[T0]], %[[DIM]] {size = array<i64: -9223372036854775808, 24, 16, 16>, torch_operand_names = ["self", "idx_0"]} :
+// CHECK-SAME:    tensor<?x384x16xf32>, index -> tensor<?x24x16x16xf32>
+// CHECK:         %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x24x16x16xf32> -> !torch.vtensor<[?,24,16,16],f32>
+// CHECK:         return %[[RES]] : !torch.vtensor<[?,24,16,16],f32>
+func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %arg1: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+  %c0 = arith.constant 0 : index
+  %int24 = torch.constant.int 24
+  %int16 = torch.constant.int 16
+  %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
+  %1 = arith.index_cast %dim_32 : index to i64
+  %2 = torch_c.from_i64 %1
+  %3 = torch.prim.ListConstruct %2, %int24, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
+  return %4 : !torch.vtensor<[?,24,16,16],f32>
+}
