Skip to content

Commit ee58041

Browse files
authored
tcp.custom_op support for torch.aten.arange with dynamic input (#100)
tcp.custom_op support for `torch.aten.arange` with dynamic input. Static case will be handled through TOSA dialect. To test: (in docker) `bazel test //...`
1 parent cac69d8 commit ee58041

File tree

4 files changed

+93
-1
lines changed

4 files changed

+93
-1
lines changed

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,51 @@ class ConvertAtenSliceScatterOp
345345
}
346346
};
347347

348+
class ConvertAtenArangeStartStepOp
349+
: public OpConversionPattern<AtenArangeStartStepOp> {
350+
using OpConversionPattern::OpConversionPattern;
351+
352+
LogicalResult
353+
matchAndRewrite(AtenArangeStartStepOp op, OpAdaptor adaptor,
354+
ConversionPatternRewriter &rewriter) const override {
355+
356+
// At this point all tensors should have value semantics, and hence the
357+
// `layout` check can be ignored.
358+
359+
// The pin_memory should be either `False` or `none`.
360+
bool pinMemory;
361+
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
362+
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
363+
pinMemory)) {
364+
return rewriter.notifyMatchFailure(
365+
op, "unimplemented: pin_memory must be either None or false");
366+
}
367+
368+
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
369+
getTypeConverter()};
370+
bool allStatic = true;
371+
// trt-mlir takes F64Attr, so we need to convert const int to fp attr
372+
if (!helper.tryConvertConstToFloatAttr("start", op.getStart())) {
373+
allStatic = false;
374+
helper.addOperand("start", adaptor.getStart());
375+
}
376+
if (!helper.tryConvertConstToFloatAttr("end", op.getEnd())) {
377+
allStatic = false;
378+
helper.addOperand("end", adaptor.getEnd());
379+
}
380+
if (!helper.tryConvertConstToFloatAttr("step", op.getStep())) {
381+
allStatic = false;
382+
helper.addOperand("step", adaptor.getStep());
383+
}
384+
// static start, end, and step case will be handled through TOSA dialect
385+
if (allStatic)
386+
return rewriter.notifyMatchFailure(op,
387+
"only non-constant values supported");
388+
389+
return helper.replace();
390+
}
391+
};
392+
348393
} // namespace
349394

350395
void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -365,8 +410,10 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
365410
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
366411
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
367412
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSliceScatterOp);
368-
// AtenViewOp can still live after torch-to-tcp conversion
413+
// Following ops can still live after torch-to-tcp conversion
369414
patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
415+
patterns.add<ConvertAtenArangeStartStepOp>(typeConverter,
416+
patterns.getContext());
370417
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
371418

372419
// Torch -> TOSA doesn't handle transposed convolutions; map them to

lib/Conversion/TorchToTcp/Utils.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,29 @@ void TorchToTcpCustomOpConversionHelper::addFloatAttr(std::string attrName,
529529
rewriter.getNamedAttr(attrName, rewriter.getF64FloatAttr(constVal)));
530530
}
531531

532+
// Attempt to fold `value` into a named f64 attribute called `attrName`.
// Succeeds when `value` is a torch constant float, or a torch constant int
// (widened to double). Returns false — adding nothing — when the value is
// non-constant or when an earlier conversion step already failed.
bool TorchToTcpCustomOpConversionHelper::tryConvertConstToFloatAttr(
    std::string attrName, Value value) {
  if (conversionResult.failed())
    return false;

  double fpVal;
  int64_t intVal;
  const bool isFloatConst =
      matchPattern(value, torch::Torch::m_TorchConstantFloat(&fpVal));
  // Convert constant int to fp if possible (only probed when not a float).
  const bool isIntConst =
      !isFloatConst &&
      matchPattern(value, torch::Torch::m_TorchConstantInt(&intVal));

  if (!isFloatConst && !isIntConst)
    return false;

  const double attrVal = isFloatConst ? fpVal : static_cast<double>(intVal);
  attrs.push_back(
      rewriter.getNamedAttr(attrName, rewriter.getF64FloatAttr(attrVal)));
  return true;
}
554+
532555
void TorchToTcpCustomOpConversionHelper::addListOfIntsAttr(std::string attrName,
533556
Value value) {
534557
if (conversionResult.failed())

lib/Conversion/TorchToTcp/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ class TorchToTcpCustomOpConversionHelper {
167167
// Add value as a named float attribute
168168
void addFloatAttr(std::string attrName, Value value);
169169

170+
// Try to convert a const value to a float attribute.
171+
bool tryConvertConstToFloatAttr(std::string attrName, Value value);
172+
170173
// Add value as a named list of integers attribute
171174
void addListOfIntsAttr(std::string attrName, Value value);
172175

test/Conversion/TorchToTcp/tcp_custom_ops.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,22 @@ func.func @torch.aten.slice_scatter(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !to
320320
%0 = torch.aten.slice_scatter %arg0, %arg1, %dim, %start, %end, %step : !torch.vtensor<[1,3],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],f32>
321321
return %0 : !torch.vtensor<[1,3],f32>
322322
}
323+
324+
// -----
325+
326+
// CHECK-LABEL: func.func @torch.aten.arange.start_step(
327+
// CHECK-SAME: %[[ARG0:.*]]: !torch.int) -> !torch.vtensor<[?],si32> {
328+
// CHECK: %[[IN:.*]] = torch_c.to_i64 %[[ARG0]]
329+
// CHECK: %[[OUT:.*]] = tcp.custom_op("torch.aten.arange.start_step") %[[IN]] {start = 0.000000e+00 : f64, step = 1.000000e+00 : f64, torch_operand_names = ["end"]} : i64 -> tensor<?xi32>
330+
// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[OUT]] : tensor<?xi32> -> !torch.vtensor<[?],si32>
331+
// CHECK: return %[[RET]] : !torch.vtensor<[?],si32>
332+
func.func @torch.aten.arange.start_step(%arg0: !torch.int) -> !torch.vtensor<[?],si32> {
333+
%false = torch.constant.bool false
334+
%none = torch.constant.none
335+
%cpu = torch.constant.device "cpu"
336+
%int0 = torch.constant.int 0
337+
%int1 = torch.constant.int 1
338+
%int3 = torch.constant.int 3
339+
%1 = torch.aten.arange.start_step %int0, %arg0, %int1, %int3, %none, %cpu, %false : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[?],si32>
340+
return %1 : !torch.vtensor<[?],si32>
341+
}

0 commit comments

Comments
 (0)