|
| 1 | +#ifdef LLAMA_PARQUET |
| 2 | +#include "parquet_dataset.h" |
| 3 | +#include <arrow/api.h> |
| 4 | +#include <arrow/io/file.h> |
| 5 | +#include <parquet/arrow/reader.h> |
| 6 | +#include "llama-impl.h" |
| 7 | + |
| 8 | +std::vector<llama_token> load_parquet_dataset(const std::string &path, const std::string &column) { |
| 9 | + arrow::MemoryPool *pool = arrow::default_memory_pool(); |
| 10 | + std::shared_ptr<arrow::io::RandomAccessFile> infile; |
| 11 | + PARQUET_ASSIGN_OR_THROW(infile, arrow::io::ReadableFile::Open(path)); |
| 12 | + arrow::Result<std::unique_ptr<parquet::arrow::FileReader>> reader_raw; |
| 13 | + PARQUET_ASSIGN_OR_THROW(reader_raw, parquet::arrow::OpenFile(infile, pool)); |
| 14 | + |
| 15 | + std::unique_ptr<parquet::arrow::FileReader> reader = std::move(reader_raw.ValueUnsafe()); |
| 16 | + std::shared_ptr<arrow::Table> table; |
| 17 | + PARQUET_THROW_NOT_OK(reader->ReadTable(&table)); |
| 18 | + |
| 19 | + auto field = table->schema()->GetFieldByName(column); |
| 20 | + if (!field || !field->type()->Equals(arrow::list(arrow::int32()))) { |
| 21 | + LLAMA_LOG_ERROR("Parquet column '%s' missing or not list<int32>", column.c_str()); |
| 22 | + return {}; |
| 23 | + } |
| 24 | + |
| 25 | + auto col = table->GetColumnByName(column); |
| 26 | + std::vector<llama_token> tokens; |
| 27 | + for (int chunk = 0; chunk < col->num_chunks(); ++chunk) { |
| 28 | + auto list_arr = std::static_pointer_cast<arrow::ListArray>(col->chunk(chunk)); |
| 29 | + auto values_arr = std::static_pointer_cast<arrow::Int32Array>(list_arr->values()); |
| 30 | + // get raw offsets (int32_t or int64_t based on ListArray template) |
| 31 | + const auto *offsets = list_arr->raw_value_offsets(); |
| 32 | + // offsets length = list_arr->length() + 1 |
| 33 | + int64_t values_length = values_arr->length(); |
| 34 | + for (int64_t i = 0; i < list_arr->length(); ++i) { |
| 35 | + int64_t start = offsets[i]; |
| 36 | + int64_t end = offsets[i + 1]; |
| 37 | + // Clamp end |
| 38 | + if (start < 0) start = 0; |
| 39 | + if (end > values_length) end = values_length; |
| 40 | + for (int64_t j = start; j < end; ++j) { |
| 41 | + tokens.push_back(static_cast<llama_token>(values_arr->Value(j))); |
| 42 | + } |
| 43 | + } |
| 44 | + } |
| 45 | + return tokens; |
| 46 | +} |
| 47 | +#endif // LLAMA_PARQUET |
0 commit comments