From 89ca978dd751e061ff94b3846c75f12a468dfe10 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 21 Apr 2025 16:57:09 -0500 Subject: [PATCH] [tuner] use python binding to get indexing maps for root op Signed-off-by: Bangtian Liu --- tuner/tuner/dispatch_parser.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index def8395280..0969712048 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -57,15 +57,8 @@ def get_problem_size(self) -> ProblemSize: contraction_dims = linalg.infer_contraction_dimensions(root_op) assert contraction_dims, "no contraction dimensions" - # TODO(Bangtian): Expose Python bindings for getting indexing maps. - indexing_maps_attr = None - for attr in root_op.opview.attributes: - if attr.name == "indexing_maps" and isinstance(attr.attr, ir.ArrayAttr): - indexing_maps_attr = attr.attr - break - - assert indexing_maps_attr, "indexing_maps attribute not found" - maps = [attr.value for attr in indexing_maps_attr] + res_maps = linalg.get_indexing_maps(root_op) + maps = [map_attr.value for map_attr in res_maps] lhs_dims = get_map_result_dim_positions(maps[0]) rhs_dims = get_map_result_dim_positions(maps[1]) res_dims = get_map_result_dim_positions(maps[2])