Skip to content

Commit 56b4795

Browse files
authored
model-conversion : add support for SentenceTransformers (#16387)
* model-conversion : add support for SentenceTransformers This commit adds support for models that use SentenceTransformer layers. The motivation for this is that if the converted model includes any of the numbered layers specified in the original model's repository then these changes enable these models to be used and verified. Currently the model-conversion only supports the base model output without any of the additional transformation layers. Usage: Convert the model that also includes the SentenceTransformer layers: ```console (venv) $ export EMBEDDING_MODEL_PATH="~/google/embeddinggemma-300M" (venv) make embedding-convert-model ``` Verify the produced embeddings from the converted model against the original model embeddings: ```console (venv) make embedding-verify-logits-st ``` The original model can be run using SentenceTransformer: ```console (venv) make embedding-run-original-model-st ``` Run the converted model using "SentenceTransformer" layers which enables pooling and normalization: ```console (venv) make embedding-run-converted-model-st ``` * add model-conversion example requirements * add support for -st flag in embedding model conversion This commit adds support for the -st flag in the embedding model conversion script. This will enable models to be converted using sentence transformers dense layers.
1 parent 2c0d875 commit 56b4795

File tree

9 files changed

+307
-150
lines changed

9 files changed

+307
-150
lines changed

examples/model-conversion/Makefile

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,20 +116,39 @@ embedding-convert-model:
116116
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
117117
./scripts/embedding/convert-model.sh
118118

119+
embedding-convert-model-st:
120+
$(call validate_embedding_model_path,embedding-convert-model-st)
121+
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
122+
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
123+
./scripts/embedding/convert-model.sh -st
124+
119125
embedding-run-original-model:
120126
$(call validate_embedding_model_path,embedding-run-original-model)
121127
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
128+
USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
122129
./scripts/embedding/run-original-model.py \
123-
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
130+
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
131+
$(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
132+
133+
embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
134+
embedding-run-original-model-st: embedding-run-original-model
124135

125136
embedding-run-converted-model:
126137
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
127-
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
138+
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
139+
$(if $(USE_POOLING),--pooling)
140+
141+
embedding-run-converted-model-st: USE_POOLING=1
142+
embedding-run-converted-model-st: embedding-run-converted-model
128143

129144
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
130145
@./scripts/embedding/compare-embeddings-logits.sh \
131146
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
132147

148+
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
149+
@./scripts/embedding/compare-embeddings-logits.sh \
150+
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
151+
133152
embedding-inspect-original-model:
134153
$(call validate_embedding_model_path,embedding-inspect-original-model)
135154
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}

examples/model-conversion/README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary
189189
file containing logits which will be used for comparison with the converted
190190
model, and the other is a text file which allows for manual visual inspection.
191191

192+
#### Using SentenceTransformer with numbered layers
193+
For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
194+
03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
195+
196+
```console
197+
# Run original model with SentenceTransformer (applies all numbered layers)
198+
(venv) $ make embedding-run-original-model-st
199+
200+
# Run converted model with pooling enabled
201+
(venv) $ make embedding-run-converted-model-st
202+
```
203+
204+
This will use the SentenceTransformer library to load and run the model, which
205+
automatically applies all the numbered layers in the correct order. This is
206+
particularly useful when comparing with models that should include these
207+
additional transformation layers beyond just the base model output.
208+
192209
### Model conversion
193210
After updates have been made to [gguf-py](../../gguf-py) to add support for the
194211
new model the model can be converted to GGUF format using the following command:
@@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
208225
(venv) $ make embedding-verify-logits
209226
```
210227

228+
For models with SentenceTransformer layers, use the `-st` verification target:
229+
```console
230+
(venv) $ make embedding-verify-logits-st
231+
```
232+
This convenience target automatically runs both the original model with SentenceTransformer
233+
and the converted model with pooling enabled, then compares the results.
234+
211235
### llama-server verification
212236
To verify that the converted model works with llama-server, the following
213237
command can be used:

examples/model-conversion/logits.cpp

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#include "llama.h"
2+
#include "common.h"
3+
4+
25
#include <cstdio>
36
#include <cstring>
47
#include <string>
@@ -8,7 +11,10 @@
811

912
static void print_usage(int, char ** argv) {
1013
printf("\nexample usage:\n");
11-
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
14+
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
15+
printf("\n");
16+
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
17+
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
1218
printf("\n");
1319
}
1420

@@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
1723
std::string prompt = "Hello, my name is";
1824
int ngl = 0;
1925
bool embedding_mode = false;
26+
bool pooling_enabled = false;
27+
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
2028

2129
{
2230
int i = 1;
@@ -41,9 +49,13 @@ int main(int argc, char ** argv) {
4149
return 1;
4250
}
4351
} else if (strcmp(argv[i], "-embd-mode") == 0) {
52+
embedding_mode = true;
53+
} else if (strcmp(argv[i], "-pooling") == 0) {
54+
pooling_enabled = true;
55+
} else if (strcmp(argv[i], "-embd-norm") == 0) {
4456
if (i + 1 < argc) {
4557
try {
46-
embedding_mode = true;
58+
embd_norm = std::stoi(argv[++i]);
4759
} catch (...) {
4860
print_usage(argc, argv);
4961
return 1;
@@ -112,7 +124,7 @@ int main(int argc, char ** argv) {
112124
ctx_params.no_perf = false;
113125
if (embedding_mode) {
114126
ctx_params.embeddings = true;
115-
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
127+
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
116128
ctx_params.n_ubatch = ctx_params.n_batch;
117129
}
118130

@@ -143,17 +155,27 @@ int main(int argc, char ** argv) {
143155
return 1;
144156
}
145157

146-
float * logits;
147-
int n_logits;
158+
float * data_ptr;
159+
int data_size;
148160
const char * type;
161+
std::vector<float> embd_out;
149162

150163
if (embedding_mode) {
151-
logits = llama_get_embeddings(ctx);
152-
n_logits = llama_model_n_embd(model) * batch.n_tokens;
164+
const int n_embd = llama_model_n_embd(model);
165+
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
166+
const int n_embeddings = n_embd * n_embd_count;
167+
float * embeddings;
153168
type = "-embeddings";
154169

155-
const int n_embd = llama_model_n_embd(model);
156-
const int n_embd_count = batch.n_tokens;
170+
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
171+
embeddings = llama_get_embeddings_seq(ctx, 0);
172+
embd_out.resize(n_embeddings);
173+
printf("Normalizing embeddings using norm: %d\n", embd_norm);
174+
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
175+
embeddings = embd_out.data();
176+
} else {
177+
embeddings = llama_get_embeddings(ctx);
178+
}
157179

158180
printf("Embedding dimension: %d\n", n_embd);
159181
printf("\n");
@@ -164,43 +186,49 @@ int main(int argc, char ** argv) {
164186

165187
// Print first 3 values
166188
for (int i = 0; i < 3 && i < n_embd; i++) {
167-
printf("%9.6f ", logits[j * n_embd + i]);
189+
printf("%9.6f ", embeddings[j * n_embd + i]);
168190
}
169191

170192
printf(" ... ");
171193

172194
// Print last 3 values
173195
for (int i = n_embd - 3; i < n_embd; i++) {
174196
if (i >= 0) {
175-
printf("%9.6f ", logits[j * n_embd + i]);
197+
printf("%9.6f ", embeddings[j * n_embd + i]);
176198
}
177199
}
178200

179201
printf("\n");
180202
}
181203
printf("\n");
182204

183-
printf("Embeddings size: %d\n", n_logits);
205+
printf("Embeddings size: %d\n", n_embeddings);
206+
207+
data_ptr = embeddings;
208+
data_size = n_embeddings;
184209
} else {
185-
logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
186-
n_logits = llama_vocab_n_tokens(vocab);
210+
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
211+
const int n_logits = llama_vocab_n_tokens(vocab);
187212
type = "";
188213
printf("Vocab size: %d\n", n_logits);
214+
215+
data_ptr = logits;
216+
data_size = n_logits;
189217
}
190218

191219
std::filesystem::create_directory("data");
192220

193-
// Save logits to binary file
221+
// Save data to binary file
194222
char bin_filename[512];
195223
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
196-
printf("Saving logits to %s\n", bin_filename);
224+
printf("Saving data to %s\n", bin_filename);
197225

198226
FILE * f = fopen(bin_filename, "wb");
199227
if (f == NULL) {
200228
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
201229
return 1;
202230
}
203-
fwrite(logits, sizeof(float), n_logits, f);
231+
fwrite(data_ptr, sizeof(float), data_size, f);
204232
fclose(f);
205233

206234
// Also save as text for debugging
@@ -211,27 +239,27 @@ int main(int argc, char ** argv) {
211239
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
212240
return 1;
213241
}
214-
for (int i = 0; i < n_logits; i++) {
215-
fprintf(f, "%d: %.6f\n", i, logits[i]);
242+
for (int i = 0; i < data_size; i++) {
243+
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
216244
}
217245
fclose(f);
218246

219247
if (!embedding_mode) {
220248
printf("First 10 logits: ");
221-
for (int i = 0; i < 10 && i < n_logits; i++) {
222-
printf("%.6f ", logits[i]);
249+
for (int i = 0; i < 10 && i < data_size; i++) {
250+
printf("%.6f ", data_ptr[i]);
223251
}
224252
printf("\n");
225253

226254
printf("Last 10 logits: ");
227-
for (int i = n_logits - 10; i < n_logits; i++) {
228-
if (i >= 0) printf("%.6f ", logits[i]);
255+
for (int i = data_size - 10; i < data_size; i++) {
256+
if (i >= 0) printf("%.6f ", data_ptr[i]);
229257
}
230258
printf("\n\n");
231259
}
232260

233-
printf("Logits saved to %s\n", bin_filename);
234-
printf("Logits saved to %s\n", txt_filename);
261+
printf("Data saved to %s\n", bin_filename);
262+
printf("Data saved to %s\n", txt_filename);
235263

236264
llama_free(ctx);
237265
llama_model_free(model);

examples/model-conversion/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ torchvision
44
transformers
55
huggingface-hub
66
accelerate
7+
sentence-transformers

examples/model-conversion/scripts/embedding/convert-model.sh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,21 @@
22

33
set -e
44

5+
# Parse command line arguments
6+
SENTENCE_TRANSFORMERS=""
7+
while [[ $# -gt 0 ]]; do
8+
case $1 in
9+
-st|--sentence-transformers)
10+
SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
11+
shift
12+
;;
13+
*)
14+
echo "Unknown option: $1"
15+
exit 1
16+
;;
17+
esac
18+
done
19+
520
MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
621
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
722
TYPE="${OUTTYPE:-f16}"
@@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}"
1530
python ../../convert_hf_to_gguf.py --verbose \
1631
${EMBEDDING_MODEL_PATH} \
1732
--outfile ${CONVERTED_MODEL} \
18-
--outtype ${TYPE}
33+
--outtype ${TYPE} \
34+
${SENTENCE_TRANSFORMERS}
1935

2036
echo ""
2137
echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:"

examples/model-conversion/scripts/embedding/run-converted-model.sh

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@ set -e
55
# Parse command line arguments
66
CONVERTED_MODEL=""
77
PROMPTS_FILE=""
8+
USE_POOLING=""
89

910
while [[ $# -gt 0 ]]; do
1011
case $1 in
1112
-p|--prompts-file)
1213
PROMPTS_FILE="$2"
1314
shift 2
1415
;;
16+
--pooling)
17+
USE_POOLING="1"
18+
shift
19+
;;
1520
*)
1621
if [ -z "$CONVERTED_MODEL" ]; then
1722
CONVERTED_MODEL="$1"
@@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
4752

4853
cmake --build ../../build --target llama-logits -j8
4954
# TODO: update logits.cpp to accept a --file/-f option for the prompt
50-
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
55+
if [ -n "$USE_POOLING" ]; then
56+
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
57+
else
58+
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
59+
fi

0 commit comments

Comments
 (0)