Skip to content

Commit 8b1b87b

Browse files
authored
[INTEL_HPU] Added support for in-place operations in the index_copy (#1496)
1 parent 4a1e884 commit 8b1b87b

File tree

3 files changed

+76
-89
lines changed

3 files changed

+76
-89
lines changed

backends/intel_hpu/custom_ops/src/index_copy.cc

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,35 @@ class IndexCopy : public HpuOperator {
2929
auto inputs = ct.GetTensors();
3030
auto outputs = ct.GetTensors(false);
3131

32+
synSectionHandle section = createSection();
33+
3234
std::vector<synTensor> syn_inputs;
33-
for (size_t i = 0; i < inputs.size(); i++) {
34-
syn_inputs.push_back(createTensor(inputs[i].dims.size(),
35-
inputs[i].type,
36-
inputs[i].dims,
37-
true,
38-
inputs[i].name));
39-
}
35+
syn_inputs.push_back(createTensor(inputs[0].dims.size(),
36+
inputs[0].type,
37+
inputs[0].dims,
38+
true,
39+
inputs[0].name,
40+
section));
41+
42+
syn_inputs.push_back(createTensor(inputs[1].dims.size(),
43+
inputs[1].type,
44+
inputs[1].dims,
45+
true,
46+
inputs[1].name));
47+
48+
syn_inputs.push_back(createTensor(inputs[2].dims.size(),
49+
inputs[2].type,
50+
inputs[2].dims,
51+
true,
52+
inputs[2].name));
4053

4154
std::vector<synTensor> syn_outputs;
42-
for (size_t i = 0; i < outputs.size(); i++) {
43-
syn_outputs.push_back(createTensor(outputs[i].dims.size(),
44-
outputs[i].type,
45-
outputs[i].dims,
46-
true,
47-
outputs[i].name));
48-
}
55+
syn_outputs.push_back(createTensor(outputs[0].dims.size(),
56+
outputs[0].type,
57+
outputs[0].dims,
58+
true,
59+
outputs[0].name,
60+
section));
4961

5062
std::string guid = guid_ + "_" + SynDataTypeToStr(outputs[0].type);
5163
synStatus status = synNodeCreate(graphHandle_,
@@ -73,19 +85,13 @@ void IndexCopyKernel(const Context& dev_ctx,
7385
const phi::DenseTensor& input,
7486
const phi::Scalar& dim,
7587
const phi::DenseTensor& index,
76-
const phi::DenseTensor& source,
77-
phi::DenseTensor* out) {
78-
dev_ctx.template Alloc<T>(out);
79-
if (out->numel() == 0) {
80-
return;
81-
}
82-
88+
const phi::DenseTensor& source) {
8389
ConvertTensors ct;
8490
ct.Add(input);
8591
ct.Add(index);
8692
ct.Add(source);
8793

88-
ct.Add(out, false);
94+
ct.Add(input, false);
8995

9096
std::vector<DIMS> inputs_dims = ct.GetDims();
9197
ns_IndexCopy::Params params{};
@@ -117,48 +123,39 @@ void CallIndexCopyKernel(const Context& dev_ctx,
117123
const phi::DenseTensor& input,
118124
const phi::Scalar& dim,
119125
const phi::DenseTensor& index,
120-
const phi::DenseTensor& source,
121-
phi::DenseTensor* out) {
126+
const phi::DenseTensor& source) {
122127
if (input.dtype() == phi::DataType::FLOAT32) {
123-
custom_kernel::IndexCopyKernel<float>(
124-
dev_ctx, input, dim, index, source, out);
128+
custom_kernel::IndexCopyKernel<float>(dev_ctx, input, dim, index, source);
125129
} else if (input.dtype() == phi::DataType::FLOAT16) {
126130
custom_kernel::IndexCopyKernel<phi::dtype::float16>(
127-
dev_ctx, input, dim, index, source, out);
131+
dev_ctx, input, dim, index, source);
128132
} else if (input.dtype() == phi::DataType::BFLOAT16) {
129133
custom_kernel::IndexCopyKernel<phi::dtype::bfloat16>(
130-
dev_ctx, input, dim, index, source, out);
134+
dev_ctx, input, dim, index, source);
131135
} else {
132136
throw std::runtime_error("Unsupported data type for IndexCopyKernel");
133137
}
134138
}
135139

136-
std::vector<paddle::Tensor> IndexCopyForward(const paddle::Tensor& input,
137-
const int dim,
138-
const paddle::Tensor& index,
139-
const paddle::Tensor& source) {
140+
void IndexCopyForward(const paddle::Tensor& input,
141+
const int dim,
142+
const paddle::Tensor& index,
143+
const paddle::Tensor& source) {
140144
auto dev_ctx = static_cast<const phi::CustomContext*>(
141145
paddle::experimental::DeviceContextPool::Instance().Get(input.place()));
142146

143147
auto input_tensor = static_cast<phi::DenseTensor*>(input.impl().get());
144148
auto index_tensor = static_cast<const phi::DenseTensor*>(index.impl().get());
145149
auto source_tensor =
146150
static_cast<const phi::DenseTensor*>(source.impl().get());
147-
auto out_tensor = std::make_shared<phi::DenseTensor>();
148-
out_tensor->Resize(input_tensor->dims());
149-
150-
CallIndexCopyKernel(*dev_ctx,
151-
*input_tensor,
152-
phi::Scalar(dim),
153-
*index_tensor,
154-
*source_tensor,
155-
out_tensor.get());
156151

157-
return {paddle::Tensor(out_tensor)};
152+
CallIndexCopyKernel(
153+
*dev_ctx, *input_tensor, phi::Scalar(dim), *index_tensor, *source_tensor);
158154
}
159155

160156
PD_BUILD_OP(index_copy)
161157
.Inputs({"input", "index", "source"})
162158
.Outputs({"out"})
163159
.Attrs({"dim: int"})
160+
.SetInplaceMap({{"input", "out"}})
164161
.SetKernelFn(PD_KERNEL(IndexCopyForward));

backends/intel_hpu/custom_ops/tests/test_index_copy.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def index_copy_torch(input, dim, index, source, dtype):
3030
"int32": torch.int32,
3131
}
3232
torch_dtype = dtype_map[dtype]
33-
input_tensor = torch.tensor(input, dtype=torch_dtype)
34-
index_tensor = torch.tensor(index, dtype=torch.int64)
35-
source_tensor = torch.tensor(source, dtype=torch_dtype)
33+
input_tensor = torch.tensor(input).clone().detach().to(dtype=torch_dtype)
34+
index_tensor = torch.tensor(index).clone().detach().to(dtype=torch.int64)
35+
source_tensor = torch.tensor(source).clone().detach().to(dtype=torch_dtype)
3636
output = torch.index_copy(
3737
input=input_tensor, dim=dim, index=index_tensor, source=source_tensor
3838
)
@@ -72,13 +72,13 @@ def check_result(self, torch_res, ops_res):
7272
np.testing.assert_allclose(torch_res, ops_res, rtol=rtol, atol=atol)
7373

7474
def index_copy_custom(self, input, dim, index, source):
75-
input_tensor = paddle.to_tensor(input, dtype=self.dtype)
76-
index_tensor = paddle.to_tensor(index, dtype="int64")
77-
source_tensor = paddle.to_tensor(source, dtype=self.dtype)
78-
out = paddlenlp_ops.index_copy(
75+
input_tensor = paddle.to_tensor(input, dtype=self.dtype).clone()
76+
index_tensor = paddle.to_tensor(index, dtype="int64").clone()
77+
source_tensor = paddle.to_tensor(source, dtype=self.dtype).clone()
78+
paddlenlp_ops.index_copy(
7979
input=input_tensor, dim=dim, index=index_tensor, source=source_tensor
8080
)
81-
return out
81+
return input_tensor
8282

8383
def prepare_input(
8484
self, batch_size=16, num_heads=32, seq_length=256, head_dim=64, dim=0, index=0
@@ -118,26 +118,26 @@ def test_index_copy_dim0_index0(self):
118118
input, index, source, dim = self.prepare_input(dim=0, index=0)
119119
custom_res = self.index_copy_custom(input, dim, index, source)
120120
torch_res = index_copy_torch(input, dim, index, source, dtype=self.dtype)
121-
self.check_result(torch_res.numpy(), custom_res.numpy())
121+
self.check_result(torch_res.numpy(), custom_res)
122122

123123
def test_index_copy_dim0_index1(self):
124124
input, index, source, dim = self.prepare_input(dim=0, index=1)
125125
custom_res = self.index_copy_custom(input, dim, index, source)
126126
torch_res = index_copy_torch(input, dim, index, source, dtype=self.dtype)
127-
self.check_result(torch_res.numpy(), custom_res.numpy())
127+
self.check_result(torch_res.numpy(), custom_res)
128128

129129
def test_index_copy_dim0_index_max(self):
130130
index = max(self.num_heads - 1, 0)
131131
input, index, source, dim = self.prepare_input(dim=0, index=index)
132132
custom_res = self.index_copy_custom(input, dim, index, source)
133133
torch_res = index_copy_torch(input, dim, index, source, dtype=self.dtype)
134-
self.check_result(torch_res.numpy(), custom_res.numpy())
134+
self.check_result(torch_res.numpy(), custom_res)
135135

136136
def test_index_copy_dim1_index0(self):
137137
input, index, source, dim = self.prepare_input(dim=1, index=0)
138138
custom_res = self.index_copy_custom(input, dim, index, source)
139139
torch_res = index_copy_torch(input, dim, index, source, dtype=self.dtype)
140-
self.check_result(torch_res.numpy(), custom_res.numpy())
140+
self.check_result(torch_res.numpy(), custom_res)
141141

142142
def test_index_copy_dim1_index1(self):
143143
input, index, source, dim = self.prepare_input(dim=1, index=1)

backends/intel_hpu/tests/test_kvcache.py

Lines changed: 25 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
# limitations under the License.
1414

1515
import paddle
16+
import paddlenlp_ops
1617

1718
paddle.set_device("intel_hpu")
1819
# paddle.set_device("cpu")
1920

2021

2122
class KVCache(paddle.nn.Layer):
22-
def __init__(self):
23+
def __init__(self, cache=None, inp_seq_len=-1):
2324
super(KVCache, self).__init__()
24-
self.cache = None
25-
self.inp_seq_len = -1
25+
print(
26+
f"`Paddle KVCache init` cache: {cache.shape if cache is not None else 'None'}, inp_seq_len: {inp_seq_len}"
27+
)
28+
self.cache = cache
29+
self.inp_seq_len = inp_seq_len
2630

2731
def allocate(self, inp_seq_len, dtype, shape):
2832
if self.cache is None or self.cache.shape != shape:
@@ -49,20 +53,7 @@ def update(prev, cur, dim, idx, inp_seq_len):
4953
return orig_cur
5054
if idx is not None:
5155
# prev.index_copy_(dim, idx - 1, cur)
52-
if dim == 0:
53-
prev.scatter_(idx, cur)
54-
else:
55-
times, temp_shape, temp_index = (
56-
paddle.prod(paddle.to_tensor(prev.shape[:dim])),
57-
prev.shape,
58-
idx,
59-
)
60-
prev, new_t = prev.reshape([-1] + temp_shape[dim + 1 :]), cur.reshape(
61-
[-1] + temp_shape[dim + 1 :]
62-
)
63-
for i in range(1, times):
64-
temp_index = paddle.concat([temp_index, idx + temp_shape[dim] * i])
65-
prev.scatter_(temp_index, new_t).reshape_(temp_shape)
56+
paddlenlp_ops.index_copy(input=prev, dim=dim, index=idx - 1, source=cur)
6657
return prev
6758
else:
6859
return paddle.concat((prev, cur), dim=dim)
@@ -77,35 +68,34 @@ def forward(self, cur, dim, idx):
7768

7869

7970
batch_size = 1
80-
num_key_value_heads = 32
81-
max_seq_len = 1024
82-
head_dim = 128
71+
num_key_value_heads = 2
72+
max_seq_len = 16
73+
head_dim = 4
8374

8475
# paddle case
8576
cache_shape = (batch_size, num_key_value_heads, max_seq_len, head_dim)
8677
dtype = "float32"
8778

88-
inp_seq_len = 128
79+
inp_seq_len = 2
8980

90-
k_cache = KVCache()
91-
k_cache.allocate(inp_seq_len, dtype, cache_shape)
81+
static_cache = paddle.zeros(cache_shape, dtype=dtype)
82+
k_cache = KVCache(static_cache, inp_seq_len)
83+
# k_cache = KVCache()
84+
# k_cache.allocate(inp_seq_len, dtype, cache_shape)
9285

93-
key_states = paddle.rand(
94-
(batch_size, num_key_value_heads, inp_seq_len, head_dim), dtype=dtype
86+
key_states = paddle.full(
87+
(batch_size, num_key_value_heads, inp_seq_len, head_dim), -1, dtype=dtype
9588
)
96-
9789
token_idx = paddle.to_tensor([0], dtype="int64")
98-
prefill = k_cache(key_states, 2, token_idx)
90+
prefill = k_cache(cur=key_states, dim=2, idx=token_idx)
91+
print(f"Paddle KVCache prefill:{prefill}")
9992

100-
print((prefill == k_cache.cache[:, :, :inp_seq_len, :]).all())
93+
for i in range(inp_seq_len + 1, max_seq_len + 1):
94+
token_idx = paddle.to_tensor([i], dtype="int64")
95+
key_state = paddle.ones((batch_size, num_key_value_heads, 1, head_dim), dtype=dtype)
96+
decode = k_cache(cur=key_state, dim=2, idx=token_idx)
97+
print(f"Paddle KVCache decode:{decode}")
10198

102-
inp_seq_len = 1
103-
token_idx = paddle.to_tensor([128], dtype="int64")
104-
key_state = paddle.ones(
105-
(batch_size, num_key_value_heads, inp_seq_len, head_dim), dtype=dtype
106-
)
107-
decode = k_cache(key_state, 2, token_idx)
108-
print((key_state == decode[:, :, token_idx, :]).all())
10999

110100
if 0:
111101
# torch case

0 commit comments

Comments
 (0)