Skip to content

Commit 304a74a

Browse files
authored
[Relax] Add FRelaxInferLayout and TMixedPrecisionPolicy for dynamic_strided_slice (#18633)
## Why

The dynamic_strided_slice operator was missing the FRelaxInferLayout and TMixedPrecisionPolicy attributes, preventing it from participating in layout transformations and mixed-precision optimizations.

## How

- Add a `TMixedPrecisionPolicy` attribute with the `kFollow` policy.
- Add an `InferLayoutDynStridedSlice` function, registered as `FRelaxInferLayout`, that falls back to the initial layout (since begin/end/strides are dynamic tensors that cannot be transformed at compile time).
1 parent e46c061 commit 304a74a

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

src/relax/op/tensor/index.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,32 @@ StructInfo InferStructInfoDynStridedSlice(const Call& call, const BlockBuilder&
549549
return TensorStructInfo(data_sinfo->dtype, n_axis, data_sinfo->vdevice);
550550
}
551551

552-
// TODO(tvm-team): Register FRelaxInferLayout, TMixedPrecisionPolicy
552+
/*!
 * \brief Layout inference for relax.dynamic_strided_slice.
 *
 * The begin/end/strides arguments are runtime tensors, so their values
 * cannot be permuted at compile time.  Layout inference therefore always
 * falls back to the initial (identity) layout of the data tensor.
 */
InferLayoutOutput InferLayoutDynStridedSlice(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  // This op does not accept a user-specified desired layout.
  ICHECK(NoDesiredLayout(call, desired_layouts));

  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  CHECK(data_sinfo) << "Invalid Call";
  CHECK(!data_sinfo->IsUnknownNdim()) << "Layout inference only supports known dimensionality, "
                                      << "but expression " << call << " has argument "
                                      << call->args[0] << " of unknown dimensionality.";

  // Keep both the input and the output in the identity layout of the
  // data tensor's rank; callers insert permute_dims around us as needed.
  const LayoutDecision fallback{InitialLayout(data_sinfo->ndim)};
  return InferLayoutOutput({fallback}, {fallback}, Attrs());
}
568+
553569
// Operator registration for relax.dynamic_strided_slice: slices `x` using
// begin/end/strides values that are only known at runtime (1-D int64 tensors).
TVM_REGISTER_OP("relax.dynamic_strided_slice")
    .set_num_inputs(4)
    .add_argument("x", "Tensor", "The source tensor to be sliced.")
    .add_argument("begin", "Tensor", "The indices to begin with in the slicing.")
    .add_argument("end", "Tensor", "Indices indicating end of the slice.")
    .add_argument("strides", "Tensor", "The stride values.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice)
    // Layout inference falls back to the initial layout, since the slice
    // bounds are runtime tensors (see InferLayoutDynStridedSlice above).
    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutDynStridedSlice)
    // kFollow: this op has no dtype preference of its own in the
    // ToMixedPrecision pass — presumably it follows its inputs; confirm
    // against the pass's policy documentation.
    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
    .set_attr<Bool>("FPurity", Bool(true));
561579

562580
} // namespace relax

tests/python/relax/test_transform_convert_layout.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5231,5 +5231,57 @@ def main(
52315231
verify(Input, Expected)
52325232

52335233

5234+
def test_conv2d_dynamic_strided_slice():
    """ConvertLayout with a dynamic_strided_slice consumer of a conv2d.

    The conv2d is rewritten to NHWC, but dynamic_strided_slice stays in the
    initial NCHW layout: its begin/end/strides are runtime tensors, so the
    pass inserts a permute_dims back to NCHW before the slice instead of
    transforming the slice itself.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            begin: R.Tensor((4,), "int64"),
            end: R.Tensor((4,), "int64"),
            strides: R.Tensor((4,), "int64"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv2 = R.dynamic_strided_slice(gv, begin, end, strides)
                R.output(gv2)
            return gv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            begin: R.Tensor((4,), dtype="int64"),
            end: R.Tensor((4,), dtype="int64"),
            strides: R.Tensor((4,), dtype="int64"),
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                # Inputs are permuted into NHWC / OHWI for the converted conv2d.
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                # The conv output is permuted back to NCHW so the slice sees
                # the initial layout.
                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    gv, axes=[0, 3, 1, 2]
                )
                gv2 = R.dynamic_strided_slice(lv2, begin, end, strides)
                R.output(gv2)
            return gv2

    verify(Input, Expected)
5284+
5285+
52345286
if __name__ == "__main__":
52355287
tvm.testing.main()

tests/python/relax/test_transform_to_mixed_precision.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,5 +1064,58 @@ def tir_identity(
10641064
tvm.ir.assert_structural_equal(Expected, After)
10651065

10661066

1067+
def test_dynamic_strided_slice():
    """ToMixedPrecision with a dynamic_strided_slice consumer of a conv2d.

    The conv2d's inputs are cast to float16 while the conv keeps a float32
    accumulator; the conv result is round-tripped through float16 and cast
    back to float32 before feeding dynamic_strided_slice, which itself stays
    in float32.
    """

    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            begin: R.Tensor((4,), "int64"),
            end: R.Tensor((4,), "int64"),
            strides: R.Tensor((4,), "int64"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                gv = R.dynamic_strided_slice(lv, begin, end, strides)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            begin: R.Tensor((4,), dtype="int64"),
            end: R.Tensor((4,), dtype="int64"),
            strides: R.Tensor((4,), dtype="int64"),
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                # Conv inputs are downcast to float16 by the pass.
                lv: R.Tensor((2, 3, 28, 28), dtype="float16") = R.astype(x, dtype="float16")
                lv1: R.Tensor((4, 3, 3, 3), dtype="float16") = R.astype(w, dtype="float16")
                lv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # Result is cast down to float16, then back up to float32
                # before the slice consumes it.
                lv3: R.Tensor((2, 4, 26, 26), dtype="float16") = R.astype(lv2, dtype="float16")
                lv4: R.Tensor((2, 4, 26, 26), dtype="float32") = R.astype(lv3, dtype="float32")
                gv: R.Tensor(None, dtype="float32", ndim=4) = R.dynamic_strided_slice(
                    lv4, begin, end, strides
                )
                R.output(gv)
            return gv

    _assert_test(Input, Expected)
1118+
1119+
10671120
if __name__ == "__main__":
10681121
tvm.testing.main()

0 commit comments

Comments
 (0)