Skip to content

Commit 1dd8b6b

Browse files
author
Judd
committed
update RobertaEmbedding for Vulkan compatibility.
1 parent 762b1b5 commit 1dd8b6b

File tree

2 files changed

+4
-9
lines changed

2 files changed

+4
-9
lines changed

src/layers.cpp

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -630,11 +630,7 @@ namespace chatllm
630630
{
631631
int qlen = (int)input->ne[0];
632632

633-
// quick fix for `before_initial_run`
634-
if (n_past + pad_index + qlen > indices->ne[0])
635-
n_past = (int)indices->ne[0] - qlen - pad_index;
636-
637-
ggml::tensor *idx = ggml::view_1d(ctx, indices, qlen, (n_past + pad_index) * ggml::element_size(indices));
633+
ggml::tensor *idx = ggml::view_1d(ctx, indices, qlen, 0);
638634

639635
ggml::tensor *output1 = ggml::get_rows(ctx, word_weight, input);
640636
ggml::tensor *output2 = ggml::get_rows(ctx, position_weight, idx);

src/layers.h

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -383,13 +383,13 @@ namespace chatllm
383383
: word_weight(ggml::new_tensor_2d(ctx, ctx->dtype, embedding_dim, num_embeddings)),
384384
position_weight(ggml::new_tensor_2d(ctx, ctx->dtype, embedding_dim, pos_max)),
385385
indices(ggml::new_tensor_1d(ctx, GGML_TYPE_I32, pos_max)),
386-
ln(ctx, embedding_dim),
387-
pad_index(2)
386+
ln(ctx, embedding_dim)
388387
{
388+
const int pad_index = 2;
389389
std::vector<int> v_indices;
390390
v_indices.resize(pos_max);
391391
for (int i = 0; i < pos_max; i++)
392-
v_indices[i] = i;
392+
v_indices[i] = pad_index + i;
393393

394394
ctx->get_allocator()->alloc(indices);
395395
Backend::write_tensor_data(indices, v_indices.data());
@@ -411,7 +411,6 @@ namespace chatllm
411411
ggml::tensor *position_weight;
412412
ggml::tensor *indices;
413413
LayerNorm ln;
414-
int pad_index;
415414
};
416415

417416
class RMSNorm : public Block

0 commit comments

Comments
 (0)