Commit d8c973e
[RELAX][LAYOUT] Support for dynamic layout specification (#18675)
This allows a user-defined callback to specify layouts dynamically based on the call description, which is helpful for altering layouts based on operator shapes or attributes. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 26c6b13 commit d8c973e
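For orientation, a minimal usage sketch of the new entry point (hedged: it only mirrors the Python signature and the callback added in the test below; `mod` stands in for an existing Relax IRModule and is not part of this change):

    from tvm import relax
    from tvm.relax.transform import ConvertLayout, Normalize

    def layout_cb(call: relax.Call):
        # Return the layout map to use for this particular call; here it simply
        # requests blocked layouts for conv2d, as in the added test.
        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}

    # `mod` is assumed to be an existing Relax IRModule with dataflow blocks.
    mod = ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]}, layout_cb)(mod)
    mod = Normalize()(mod)

When layout_cb is supplied it takes precedence over the static desired_layouts map (see the docstring change below).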

File tree

4 files changed: 124 additions, 14 deletions

include/tvm/relax/transform.h (4 additions, 1 deletion)

@@ -41,6 +41,7 @@ using PassContext = tvm::transform::PassContext;
 using Function = tvm::relax::Function;
 using DataflowBlock = tvm::relax::DataflowBlock;
 using tvm::transform::CreateModulePass;
+using LayoutCb = ffi::TypedFunction<ffi::Map<ffi::String, ffi::Array<ffi::String>>(Call)>;

 /*!
  * \brief Create a function pass.
@@ -606,10 +607,12 @@ TVM_DLL Pass AlterOpImpl(
 /*!
  * \brief Layout conversion pass.
  * \param desired_layouts The desired layouts for some operators.
+ * \param layout_cb custom call back to define layouts dynamically.
  * \return The Pass.
  * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
  */
-TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts);
+TVM_DLL Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                           LayoutCb layout_cb);

 /*!
  * \brief A pass that converts consecutive dataflow operations

python/tvm/relax/transform/transform.py (8 additions, 2 deletions)

@@ -1367,7 +1367,10 @@ def AlterOpImpl(
     )  # type: ignore


-def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
+def ConvertLayout(
+    desired_layouts: Dict[str, List[str]],
+    layout_cb: Callable = None,
+) -> tvm.ir.transform.Pass:
     """Automatic layout conversion pass.

     Parameters
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
         of the desired feature map, weight and output. For example, if we want to convert the
         layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+    layout_cb : Callable
+        A user defined call back function that can dynamically handle operator layouts
+        based on Call description. desired_layouts will be ignored if layout_cb is defined.

     Returns
     -------
     ret : tvm.transform.Pass
         The registered pass for layout conversion.
     """
-    return _ffi_api.ConvertLayout(desired_layouts)  # type: ignore
+    return _ffi_api.ConvertLayout(desired_layouts, layout_cb)  # type: ignore


 def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
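To illustrate the "dynamically handle operator layouts based on Call description" behavior described in the docstring above, here is a hedged sketch of a callback that branches per call; reading the channel extent through struct_info is purely an illustrative assumption, not something this patch prescribes:

    from tvm import relax

    def layout_cb(call: relax.Call):
        # Illustrative heuristic: use a blocked layout when the channel dim of the
        # first (NCHW) argument divides evenly by 4, otherwise stay with NHWC.
        shape = call.args[0].struct_info.shape
        if int(shape[1]) % 4 == 0:
            return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
        return {"relax.nn.conv2d": ["NHWC", "OHWI"]}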

src/relax/transform/convert_layout.cc (19 additions, 9 deletions)

@@ -38,6 +38,7 @@ namespace relax {

 using tir::IndexMap;
 using tir::Layout;
+using LayoutCb = tvm::relax::transform::LayoutCb;

 /*!
  * \brief Main logic to convert the layout of conv2d. Other ops
@@ -79,8 +80,8 @@ using tir::Layout;
 class LayoutConvertMutator : public ExprMutator {
  public:
   explicit LayoutConvertMutator(
-      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts)
-      : desired_layouts_(desired_layouts) {}
+      const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts, LayoutCb layout_cb)
+      : desired_layouts_(desired_layouts), layout_cb_(layout_cb) {}

  private:
   ffi::Array<Integer> LayoutToIntegers(const Layout& layout) {
@@ -201,15 +202,21 @@
   ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
       const CallNode* call_node,
       const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
-      const VarLayoutMap& var_layout_map) {
+      const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
     const OpNode* op_node = call_node->op.as<OpNode>();
     if (op_node == nullptr) return std::nullopt;
     Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
     const auto attr_map = Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known ndim
       FRelaxInferLayout f = attr_map[op];
-      return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      auto call = ffi::GetRef<Call>(call_node);
+      if (layout_cb != nullptr) {
+        auto custom_layouts = layout_cb(call);
+        return f(call, custom_layouts, var_layout_map);
+      } else {
+        return f(call, desired_layouts, var_layout_map);
+      }
     } else {
       // Otherwise, we use the default policy.
       return std::nullopt;
@@ -218,7 +225,7 @@

   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
     ffi::Optional<InferLayoutOutput> res =
-        GetInferLayoutInfo(call_node, desired_layouts_, var_layout_map_);
+        GetInferLayoutInfo(call_node, desired_layouts_, layout_cb_, var_layout_map_);
     ObjectPtr<CallNode> new_call = ffi::make_object<CallNode>(*call_node);
     new_call->struct_info_ = std::nullopt;
     if (!res.defined() ||
@@ -335,20 +342,23 @@

   std::unordered_map<Var, NLayout> var_layout_map_;
   ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts_;
+  LayoutCb layout_cb_;
 };  // namespace relax

 DataflowBlock ConvertLayoutPass(const DataflowBlock& df_block,
-                                ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
-  LayoutConvertMutator mutator(desired_layouts);
+                                ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                                LayoutCb layout_cb) {
+  LayoutConvertMutator mutator(desired_layouts, layout_cb);
   return Downcast<DataflowBlock>(mutator.VisitBindingBlock(df_block));
 }

 namespace transform {

-Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts) {
+Pass ConvertLayout(ffi::Map<ffi::String, ffi::Array<ffi::String>> desired_layouts,
+                   LayoutCb layout_cb) {
   ffi::TypedFunction<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
       [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts));
+        return Downcast<DataflowBlock>(ConvertLayoutPass(df_block, desired_layouts, layout_cb));
       };
   return CreateDataflowBlockPass(pass_func, 0, "ConvertLayout", {});
 }

tests/python/relax/test_transform_convert_layout.py (93 additions, 2 deletions)

@@ -21,10 +21,10 @@
 from tvm.script.parser import ir as I, relax as R, tir as T


-def verify(input, expected, extra_ops={}):
+def verify(input, expected, extra_ops={}, cb=None):
     desired_layouts = {"relax.nn.conv2d": ["NHWC", "OHWI"]}
     desired_layouts.update(extra_ops)
-    mod = ConvertLayout(desired_layouts)(input)
+    mod = ConvertLayout(desired_layouts, cb)(input)
     mod = Normalize()(mod)
     tvm.ir.assert_structural_equal(mod, expected)

@@ -5487,5 +5487,96 @@ def main(
     verify(Input, Expected)


+def test_layout_cb():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), "float32"),
+            w: R.Tensor((4, 4, 3, 3), "float32"),
+            bias: R.Tensor((2, 4, 26, 26), "float32"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias)
+                gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2)
+                gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32")
+                R.output(gv4)
+            return gv4
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 4, 3, 3), dtype="float32"),
+            bias: R.Tensor((2, 4, 26, 26), dtype="float32"),
+        ) -> R.Tensor((2, 4, 24, 24), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 1, 28, 28, 4), dtype="float32") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                lv1: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                gv: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.layout_transform(
+                    bias,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32"
+                    ),
+                )
+                gv2: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.add(gv, lv2)
+                gv3: R.Tensor((2, 1, 26, 26, 4), dtype="float32") = R.nn.relu(gv2)
+                lv3: R.Tensor((1, 4, 3, 3, 4), dtype="float32") = R.layout_transform(
+                    w,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32"
+                    ),
+                )
+                lv4: R.Tensor((2, 1, 24, 24, 4), dtype="float32") = R.nn.conv2d(
+                    gv3,
+                    lv3,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW4c",
+                    kernel_layout="OIHW4o",
+                    out_layout="NCHW4c",
+                    out_dtype="float32",
+                )
+                gv4: R.Tensor((2, 4, 24, 24), dtype="float32") = R.layout_transform(
+                    lv4,
+                    index_map=T.index_map(
+                        lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32"
+                    ),
+                )
+                R.output(gv4)
+            return gv4
+
+    def layout_cb(call: tvm.relax.Call):
+        return {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
+
+    verify(Input, Expected, cb=layout_cb)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
