Skip to content

Commit 56b15ef

Browse files
authored
[INTEL_HPU] fix issue that graph recompile not triggered by index dtype change (#1861)
1 parent 3a732ec commit 56b15ef

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

backends/intel_hpu/kernels/index_select_kernel.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "habanalabs/perf_lib_layer_params.h"
1616
#include "kernels/funcs.h"
17+
#include "kernels/hpu_funcs.h"
1718
#include "kernels/hpu_operator.h"
1819
#include "utils/utils.h"
1920

@@ -85,12 +86,21 @@ void IndexSelectKernel(const Context& dev_ctx,
8586
dim += x.dims().size();
8687
}
8788

89+
std::string op_name = "IndexSelectKernel";
90+
if (index.dtype() == phi::DataType::INT32) {
91+
op_name += "_int32";
92+
} else if (index.dtype() == phi::DataType::INT64) {
93+
op_name += "_int64";
94+
} else {
95+
throw std::runtime_error(
96+
"index_select supports only int64 and int32 for index!");
97+
}
98+
8899
OpCacheOperator op_info;
89100
IndexSelectParams params;
90101
params.params.axis = static_cast<int32_t>(x.dims().size()) - 1 - dim;
91102
std::vector<DIMS> inputs_dims = ct.GetDims();
92-
op_info.prepareOpInfo<T, IndexSelectParams>(
93-
"IndexSelectKernel", inputs_dims, &params);
103+
op_info.prepareOpInfo<T, IndexSelectParams>(op_name, inputs_dims, &params);
94104
auto recipe = op_info.GetRecipe();
95105
if (recipe == nullptr) {
96106
IndexSelect op;

backends/intel_hpu/tests/unittests/test_index_select.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,18 @@ def config(self):
114114
self.index_type = np.int64
115115

116116

117+
class TestMixOfInt32andInt64(unittest.TestCase):
118+
def test_mix_int32_int64(self):
119+
paddle.disable_static()
120+
x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
121+
indices_32 = paddle.to_tensor([0, 2], dtype=paddle.int32)
122+
result1 = paddle.index_select(x, axis=0, index=indices_32)
123+
124+
indices_64 = paddle.to_tensor([0, 2], dtype=paddle.int64)
125+
result2 = paddle.index_select(x, axis=0, index=indices_64)
126+
self.assertTrue(np.array_equal(result1.numpy(), result2.numpy()))
127+
paddle.enable_static()
128+
129+
117130
if __name__ == "__main__":
118131
unittest.main()

0 commit comments

Comments
 (0)