
Commit ec9ca69

Author: lexasub

train: add simple loading already tokenized data from parquet dataset

1 parent 36e257d · commit ec9ca69

File tree — 8 files changed: +143 −17 lines

- CMakeLists.txt
- common/arg.cpp
- common/common.h
- examples/training/CMakeLists.txt
- examples/training/README.md
- examples/training/finetune.cpp
- examples/training/parquet_dataset.cpp
- examples/training/parquet_dataset.h

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -86,6 +86,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
 option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
+option(LLAMA_PARQUET "Enable Parquet dataset support via Arrow/Parquet C++" OFF)
 
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
```

common/arg.cpp

Lines changed: 37 additions & 12 deletions
```diff
@@ -1471,14 +1471,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
         [](common_params & params, int value) {
             params.n_chunks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
         {"-fa", "--flash-attn"},
         string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2116,70 +2116,70 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.hellaswag = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--hellaswag-tasks"}, "N",
         string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks),
         [](common_params & params, int value) {
             params.hellaswag_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--winogrande"},
         "compute Winogrande score over random tasks from datafile supplied with -f",
         [](common_params & params) {
             params.winogrande = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--winogrande-tasks"}, "N",
         string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks),
         [](common_params & params, int value) {
             params.winogrande_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--multiple-choice"},
         "compute multiple choice score over random tasks from datafile supplied with -f",
         [](common_params & params) {
             params.multiple_choice = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--multiple-choice-tasks"}, "N",
         string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks),
         [](common_params & params, int value) {
             params.multiple_choice_tasks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--kl-divergence"},
         "computes KL-divergence to logits provided via --kl-divergence-base",
         [](common_params & params) {
             params.kl_divergence = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--save-all-logits", "--kl-divergence-base"}, "FNAME",
         "set logits file",
         [](common_params & params, const std::string & value) {
             params.logits_file = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--ppl-stride"}, "N",
         string_format("stride for perplexity calculation (default: %d)", params.ppl_stride),
         [](common_params & params, int value) {
             params.ppl_stride = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"--ppl-output-type"}, "<0|1>",
         string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type),
         [](common_params & params, int value) {
             params.ppl_output_type = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    ).set_examples({LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
         string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),
@@ -3470,5 +3470,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
 
+#ifdef LLAMA_PARQUET
+    add_opt(common_arg(
+        {"--dataset-format"}, "text",
+        string_format("Dataset format: text or parquet (requires LLAMA_PARQUET)"),
+        [](common_params & params, const std::string & format) {
+            params.dataset_format = format; // "text" or "parquet" // TODO: enum class
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+
+    add_opt(common_arg(
+        {"--parquet-path"}, "parquet.parquet",
+        string_format("Parquet path"),
+        [](common_params & params, const std::string & filepath) { // TODO: read a directory of files
+            params.parquet_path = filepath;
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+
+    add_opt(common_arg(
+        {"--tokens-column"}, "tokens",
+        string_format("Name of tokens column (list<int32>) in Parquet file"),
+        [](common_params & params, const std::string & column) {
+            params.tokens_column = column;
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+#endif
     return ctx_arg;
 }
```
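The `TODO: enum class` note flags that `dataset_format` stays a raw string compared at every use site. A minimal sketch of the typed alternative the TODO hints at — the names `dataset_format_t` and `dataset_format_from_string` are hypothetical, not part of this commit:

```cpp
#include <string>

// Hypothetical follow-up to the TODO above: parse --dataset-format once into a
// typed value instead of passing "text"/"parquet" strings around.
enum class dataset_format_t { TEXT, PARQUET };

static dataset_format_t dataset_format_from_string(const std::string & s) {
    if (s == "parquet") {
        return dataset_format_t::PARQUET;
    }
    return dataset_format_t::TEXT; // "text" is the documented default
}
```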

common/common.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -87,6 +87,7 @@ enum llama_example {
     LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
+    LLAMA_EXAMPLE_FINETUNE,
 };
@@ -305,6 +306,9 @@ struct common_params {
     std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file = ""; // file for saving *all* logits // NOLINT
+    std::string dataset_format = "text"; // "text" | "parquet"
+    std::string parquet_path; // path to the Parquet file
+    std::string tokens_column = "tokens"; // name of the list<int32> tokens column
 
     std::vector<std::string> in_files; // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
```

examples/training/CMakeLists.txt

Lines changed: 14 additions & 2 deletions
```diff
@@ -1,5 +1,17 @@
 set(TARGET llama-finetune)
-add_executable(${TARGET} finetune.cpp)
+add_executable(${TARGET} finetune.cpp parquet_dataset.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+
+
+if(LLAMA_PARQUET)
+    find_package(Arrow REQUIRED)
+    find_package(Parquet REQUIRED)
+    add_definitions(-DLLAMA_PARQUET)
+endif()
+
+if(LLAMA_PARQUET)
+    target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} Arrow::arrow_shared Parquet::parquet_shared)
+else()
+    target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+endif()
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
```

examples/training/README.md

Lines changed: 11 additions & 0 deletions
````diff
@@ -8,10 +8,21 @@ Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory.
 
 Proof of concept:
 
+Loading data from a plain text file:
+
 ``` sh
 export model_name=llama_3.2-1b && export quantization=f32
 ./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
 ./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
 ```
 
+Loading pre-tokenized data from a Parquet file (without batching):
+
+This requires the Arrow/Parquet C++ libraries to be installed and a build configured with LLAMA_PARQUET=ON:
+
+``` sh
+mkdir build; cd build; cmake -DLLAMA_PARQUET=ON .. ; make
+export model_name=llama_3.2-1b && export quantization=f32
+./build/bin/llama-finetune -ngl 999 --dataset-format parquet --parquet-path parquet.parquet --tokens-column tokens --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512
+```
 The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs.
````
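For a quick end-to-end test, a compatible input file can be produced with the same Arrow/Parquet C++ API the loader uses. A minimal sketch, not part of the commit — the output name `parquet.parquet`, the `tokens` column, and the token ids simply mirror the defaults above; real token ids would come from a tokenizer:

```cpp
// Sketch: write a one-column Parquet file ("tokens": list<int32>) that
// llama-finetune --dataset-format parquet can consume. Token ids are dummies.
#include <arrow/api.h>
#include <arrow/io/file.h>
#include <parquet/arrow/writer.h>
#include <parquet/exception.h>

int main() {
    arrow::MemoryPool * pool = arrow::default_memory_pool();
    auto value_builder = std::make_shared<arrow::Int32Builder>(pool);
    arrow::ListBuilder list_builder(pool, value_builder);

    // One row holding a short dummy token sequence.
    PARQUET_THROW_NOT_OK(list_builder.Append());
    PARQUET_THROW_NOT_OK(value_builder->AppendValues({1, 2, 3, 4, 5}));

    std::shared_ptr<arrow::Array> tokens_array;
    PARQUET_THROW_NOT_OK(list_builder.Finish(&tokens_array));

    auto schema = arrow::schema({arrow::field("tokens", arrow::list(arrow::int32()))});
    auto table  = arrow::Table::Make(schema, {tokens_array});

    std::shared_ptr<arrow::io::FileOutputStream> outfile;
    PARQUET_ASSIGN_OR_THROW(outfile, arrow::io::FileOutputStream::Open("parquet.parquet"));
    PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*table, pool, outfile, /*chunk_size=*/1024));
    return 0;
}
```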

examples/training/finetune.cpp

Lines changed: 19 additions & 3 deletions
```diff
@@ -9,6 +9,8 @@
 #include <ctime>
 #include <vector>
 
+#include "parquet_dataset.h"
+
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
@@ -55,9 +57,23 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
-
+    std::vector<llama_token> tokens;
+#ifdef LLAMA_PARQUET
+    if (params.dataset_format == "text") {
+#endif
+        tokens = common_tokenize(pctx, params.prompt, true); // load from a plain text file
+#ifdef LLAMA_PARQUET
+    }
+    else if (params.dataset_format == "parquet") {
+        tokens = load_parquet_dataset(params.parquet_path, params.tokens_column);
+        if (tokens.empty()) {
+            LOG_ERR("No tokens in %s, or column %s not found/invalid\n", params.parquet_path.c_str(), params.tokens_column.c_str());
+            return 1;
+        }
+        LOG_INF("Loaded %zu tokens from Parquet\n", tokens.size());
+    }
+#endif
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
     struct lr_opt & lr = params.lr;
     LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
         ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.min_epochs,
```
examples/training/parquet_dataset.cpp

Lines changed: 47 additions & 0 deletions

```diff
@@ -0,0 +1,47 @@
+#ifdef LLAMA_PARQUET
+#include "parquet_dataset.h"
+#include <arrow/api.h>
+#include <arrow/io/file.h>
+#include <parquet/arrow/reader.h>
+#include "llama-impl.h"
+
+std::vector<llama_token> load_parquet_dataset(const std::string & path, const std::string & column) {
+    arrow::MemoryPool * pool = arrow::default_memory_pool();
+    std::shared_ptr<arrow::io::RandomAccessFile> infile;
+    PARQUET_ASSIGN_OR_THROW(infile, arrow::io::ReadableFile::Open(path));
+    arrow::Result<std::unique_ptr<parquet::arrow::FileReader>> reader_raw;
+    PARQUET_ASSIGN_OR_THROW(reader_raw, parquet::arrow::OpenFile(infile, pool));
+
+    std::unique_ptr<parquet::arrow::FileReader> reader = std::move(reader_raw.ValueUnsafe());
+    std::shared_ptr<arrow::Table> table;
+    PARQUET_THROW_NOT_OK(reader->ReadTable(&table));
+
+    auto field = table->schema()->GetFieldByName(column);
+    if (!field || !field->type()->Equals(arrow::list(arrow::int32()))) {
+        LLAMA_LOG_ERROR("Parquet column '%s' missing or not list<int32>\n", column.c_str());
+        return {};
+    }
+
+    auto col = table->GetColumnByName(column);
+    std::vector<llama_token> tokens;
+    for (int chunk = 0; chunk < col->num_chunks(); ++chunk) {
+        auto list_arr = std::static_pointer_cast<arrow::ListArray>(col->chunk(chunk));
+        auto values_arr = std::static_pointer_cast<arrow::Int32Array>(list_arr->values());
+        // get raw offsets (int32_t or int64_t based on ListArray template)
+        const auto * offsets = list_arr->raw_value_offsets();
+        // offsets length = list_arr->length() + 1
+        int64_t values_length = values_arr->length();
+        for (int64_t i = 0; i < list_arr->length(); ++i) {
+            int64_t start = offsets[i];
+            int64_t end = offsets[i + 1];
+            // clamp to the values array bounds
+            if (start < 0) start = 0;
+            if (end > values_length) end = values_length;
+            for (int64_t j = start; j < end; ++j) {
+                tokens.push_back(static_cast<llama_token>(values_arr->Value(j)));
+            }
+        }
+    }
+    return tokens;
+}
+#endif // LLAMA_PARQUET
```
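One caveat in the loader's inner loop: null rows and null values pass straight through `Value(j)` as if they were real token ids. A hypothetical hardened variant of the per-chunk loop, reusing the loader's variable names — a sketch, not part of the commit:

```cpp
// Sketch: skip null lists and null values instead of ingesting them as tokens.
for (int64_t i = 0; i < list_arr->length(); ++i) {
    if (list_arr->IsNull(i)) {
        continue; // null row: no token sequence
    }
    const int64_t start = list_arr->value_offset(i);
    const int64_t end   = list_arr->value_offset(i + 1);
    for (int64_t j = start; j < end; ++j) {
        if (!values_arr->IsNull(j)) {
            tokens.push_back(static_cast<llama_token>(values_arr->Value(j)));
        }
    }
}
```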
examples/training/parquet_dataset.h

Lines changed: 10 additions & 0 deletions

```diff
@@ -0,0 +1,10 @@
+#ifndef PARQUET_DATASET_H
+#define PARQUET_DATASET_H
+#include <string>
+#include <vector>
+#include "llama.h"
+
+#ifdef LLAMA_PARQUET
+std::vector<llama_token> load_parquet_dataset(const std::string & path, const std::string & column);
+#endif
+#endif // PARQUET_DATASET_H
```
