Skip to content

Commit 976203c

Browse files
authored
[XPU ]fix text_image_gather_scatter in cudagraph mode(#6049)
1 parent 20074d3 commit 976203c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/text_image_gather_scatter.xpu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ static __device__ inline void text_image_gather(
6060
if (token_type == 0) {
6161
text_image_input = text_input;
6262
text_image_index = text_index;
63-
} else {
63+
} else if (token_type == 1) {
6464
text_image_input = image_input;
6565
text_image_index = image_index;
66+
} else {
67+
continue;
6668
}
6769
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
6870
int input_offset = i * hidden_size;
@@ -132,9 +134,11 @@ static __device__ inline void text_image_scatter(
132134
if (token_type == 0) {
133135
text_image_input = text_input;
134136
text_image_index = text_index;
135-
} else {
137+
} else if (token_type == 1) {
136138
text_image_input = image_input;
137139
text_image_index = image_index;
140+
} else {
141+
continue;
138142
}
139143
GM2LM(text_image_index + i, &text_image_token_idx, sizeof(int));
140144
int input_offset = i * hidden_size;

0 commit comments

Comments
 (0)