
Commit 29062f8

support lowering for aten.log1p (#93)
Lower the aten.log1p op to tcp.log(tcp.add(input, 1.0)). To test: `bazel test //...` (in Docker).
Parent: 57d5e00
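For context, this lowering relies on the identity log1p(x) = log(x + 1). A minimal Python sketch of the equivalence, including the precision caveat that a dedicated log1p kernel avoids for very small inputs (the snippet is illustrative, not code from this repo):

```python
import math

x = 0.5
# The lowering computes log(x + 1.0), which matches log1p for ordinary inputs.
assert math.isclose(math.log1p(x), math.log(x + 1.0))

# Caveat: for |x| near machine epsilon, 1.0 + x rounds away low-order bits,
# so log(1.0 + x) loses accuracy relative to a true log1p.
tiny = 1e-10
print(math.log1p(tiny))      # accurate: ~1.0e-10
print(math.log(1.0 + tiny))  # rounding error visible in the low digits
```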

File tree: 4 files changed, +70 −0 lines

lib/Conversion/TorchToTcp/Elementwise.cpp

Lines changed: 32 additions & 0 deletions
```diff
@@ -478,6 +478,37 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
   }
 };
 
+class ConvertAtenLog1pOp : public OpConversionPattern<AtenLog1pOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+
+    if (!inputType)
+      return rewriter.notifyMatchFailure(
+          op, "Only Ranked Tensor types are supported in TCP");
+
+    auto elementType = inputType.getElementType();
+    if (!isa<mlir::FloatType>(elementType))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype is supported");
+
+    auto constOp = torch_to_tcp::getConstTensor<float>(
+                       rewriter, op, llvm::ArrayRef((float)1.0), {})
+                       .value();
+    constOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape(
+        rewriter, constOp, input, elementType);
+    auto addOp =
+        rewriter.create<tcp::AddOp>(op.getLoc(), inputType, input, constOp);
+    rewriter.replaceOpWithNewOp<tcp::LogOp>(op, inputType, addOp);
+    return success();
+  }
+};
+
 template <typename AtenOpT, typename TcpOpT>
 class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -694,6 +725,7 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp);
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op);
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenSqrtOp);
+  INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenLog1pOp);
 #undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN
 
 #define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \
```
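The emitted TCP sequence mirrors the PyTorch expression below. This is a hedged sketch of the same dataflow (a 0-D constant 1.0, broadcast to the input shape, then add, then log), not code from this repo:

```python
import torch

def log1p_as_lowered(x: torch.Tensor) -> torch.Tensor:
    # tcp.const: a 0-D splat constant 1.0 with the input's element type.
    one = torch.tensor(1.0, dtype=x.dtype)
    # PyTorch broadcasts the 0-D tensor implicitly; the pattern needs the
    # explicit broadcast0DOr1DToNDAndMatchShape step because tcp.add
    # expects operands of matching shape.
    return torch.log(torch.add(x, one))  # tcp.add, then tcp.log

x = torch.rand(2, 3)
assert torch.allclose(log1p_as_lowered(x), torch.log1p(x))
```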

test/AotCompile/BUILD

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@ AOT_TEST_SUITE = [
     ("tanh", False),
     ("clamp", False),
     ("relu", False),
+    ("log1p", False),
     ("round_even", False),
     ("sqrt_float", False),
     ("sqrt_int", False),
```

test/AotCompile/model_loader_lib.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -263,6 +263,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     return TorchLoaderOutput(model=Relu(), inputs=(x,), dynamic_shapes=dynamic_shapes)
 
 
+def log1p_loader() -> TorchLoaderOutput:
+    class Log1p(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return torch.log1p(x)
+
+    # Sample inputs
+    x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+
+    # Dynamic dim constraints
+    batch = Dim("batch")
+    dynamic_shapes = {"x": {0: batch}}
+
+    return TorchLoaderOutput(model=Log1p(), inputs=(x,), dynamic_shapes=dynamic_shapes)
+
+
 def round_even_loader() -> TorchLoaderOutput:
     class RoundEven(torch.nn.Module):
         def __init__(self):
```
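TorchLoaderOutput and the AOT test harness are repo-specific. As a standalone sketch under that assumption, the harness presumably consumes the loader's output roughly like this, via the public torch.export API, with the leading dim left dynamic:

```python
import torch
from torch.export import Dim, export

class Log1p(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log1p(x)

x = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
batch = Dim("batch")

# Export with a symbolic batch dimension, matching the loader's
# dynamic_shapes = {"x": {0: batch}}.
ep = export(Log1p(), (x,), dynamic_shapes={"x": {0: batch}})

# The exported program now accepts any batch size along dim 0.
inp = torch.rand(5, 3)
assert torch.allclose(ep.module()(inp), torch.log1p(inp))
```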

test/Conversion/TorchToTcp/elementwise.mlir

Lines changed: 19 additions & 0 deletions
```diff
@@ -762,3 +762,22 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],ui8> {
   %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
   return %0 : !torch.vtensor<[?,?],ui8>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.log1p(
+// CHECK-SAME:      %[[ARG0:.*]]: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
+// CHECK-DAG:   %[[TO_BUILTIN0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4,19,2],f32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[CONST:.*]] = tcp.const {value = dense<1.000000e+00> : tensor<f32>} : tensor<f32>
+// CHECK:       %[[EXPAND_SHAPE:.*]] = tensor.expand_shape %[[CONST]] [] output_shape [1, 1, 1, 1] : tensor<f32> into tensor<1x1x1x1xf32>
+// CHECK:       %[[CONST0:.*]] = arith.constant 0 : index
+// CHECK:       %[[DIM0:.*]] = tensor.dim %[[TO_BUILTIN0]], %[[CONST0]] : tensor<?x4x19x2xf32>
+// CHECK:       %[[BROADCAST:.*]] = tcp.broadcast %[[EXPAND_SHAPE]], %[[DIM0]]
+// CHECK:       %[[ADD:.*]] = tcp.add %[[TO_BUILTIN0]], %[[BROADCAST]] : tensor<?x4x19x2xf32>, tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[LOG:.*]] = tcp.log %[[ADD]] : tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[FROM_BUILTIN:.*]] = torch_c.from_builtin_tensor %[[LOG]] : tensor<?x4x19x2xf32> -> !torch.vtensor<[?,4,19,2],f32>
+// CHECK:       return %[[FROM_BUILTIN]] : !torch.vtensor<[?,4,19,2],f32>
+func.func @torch.aten.log1p(%arg0: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
+  %1 = torch.aten.log1p %arg0 : !torch.vtensor<[?,4,19,2],f32> -> !torch.vtensor<[?,4,19,2],f32>
+  return %1 : !torch.vtensor<[?,4,19,2],f32>
+}
```
