diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 470dc3d916b90..a096eef8f6386 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -12,6 +12,16 @@ #include #include #include +#include +#include + +#ifdef _WIN32 +#include +#define getpid GetCurrentProcessId +#else +#include +#include +#endif struct quant_option { std::string name; @@ -76,6 +86,36 @@ static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; +// Check if two paths refer to the same physical file +// Returns true if they are the same file (including via hardlinks or symlinks) +static bool same_file(const std::string & path_a, const std::string & path_b) { + std::error_code ec_a, ec_b; + + // First try using std::filesystem to resolve canonical paths + auto canonical_a = std::filesystem::weakly_canonical(path_a, ec_a); + auto canonical_b = std::filesystem::weakly_canonical(path_b, ec_b); + + if (!ec_a && !ec_b) { + // If both paths were successfully canonicalized, compare them + if (canonical_a == canonical_b) { + return true; + } + } + +#ifndef _WIN32 + // On Unix-like systems, also check using stat() to handle hardlinks + struct stat stat_a, stat_b; + if (stat(path_a.c_str(), &stat_a) == 0 && stat(path_b.c_str(), &stat_b) == 0) { + // Same file if device and inode match + if (stat_a.st_dev == stat_b.st_dev && stat_a.st_ino == stat_b.st_ino) { + return true; + } + } +#endif + + return false; +} + static bool striequals(const char * a, const char * b) { while (*a && *b) { if (std::tolower(*a) != std::tolower(*b)) { @@ -119,7 +159,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp static void usage(const char * executable) { printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n"); - printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); + printf(" [--inplace] [--overwrite] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); @@ -135,6 +175,8 @@ static void usage(const char * executable) { printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); + printf(" --inplace: Allow in-place quantization (input == output). Uses temp file + atomic rename for safety.\n"); + printf(" --overwrite: Allow overwriting existing output file (still uses temp file for safety).\n"); printf("Note: --include-weights and --exclude-weights cannot be used together\n"); printf("\nAllowed quantization types:\n"); for (const auto & it : QUANT_OPTIONS) { @@ -453,6 +495,8 @@ int main(int argc, char ** argv) { std::vector kv_overrides; std::vector tensor_types; std::vector prune_layers; + bool allow_inplace = false; + bool allow_overwrite = false; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -511,6 +555,10 @@ int main(int argc, char ** argv) { } } else if (strcmp(argv[arg_idx], "--keep-split") == 0) { params.keep_split = true; + } else if (strcmp(argv[arg_idx], "--inplace") == 0) { + allow_inplace = true; + } else if (strcmp(argv[arg_idx], "--overwrite") == 0) { + allow_overwrite = true; } else { usage(argv[0]); } @@ -645,6 +693,62 @@ int main(int argc, char ** argv) { print_build_info(); + // Check if input and output refer to the same physical file + // This prevents catastrophic data loss from truncating the input while reading it + if (same_file(fname_inp, fname_out)) { + if (!allow_inplace) { + fprintf(stderr, "\n==========================================================================================================\n"); + fprintf(stderr, "ERROR: Input and output files are the same: '%s'\n", fname_inp.c_str()); + fprintf(stderr, "This would truncate the input file while reading it, causing data corruption and SIGBUS errors.\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "Solutions:\n"); + fprintf(stderr, " 1. Specify a different output filename\n"); + fprintf(stderr, " 2. Use --inplace flag to enable safe in-place quantization (uses temp file + atomic rename)\n"); + fprintf(stderr, "==========================================================================================================\n\n"); + llama_backend_free(); + return 1; + } + fprintf(stderr, "%s: WARNING: in-place quantization detected, using temporary file for safety\n", __func__); + } + + // Check if output file already exists (unless overwrite is allowed) + if (!allow_overwrite && !allow_inplace) { + std::error_code ec; + if (std::filesystem::exists(fname_out, ec)) { + fprintf(stderr, "\n==========================================================================================================\n"); + fprintf(stderr, "ERROR: Output file already exists: '%s'\n", fname_out.c_str()); + fprintf(stderr, "Use --overwrite flag to allow overwriting existing files.\n"); + fprintf(stderr, "==========================================================================================================\n\n"); + llama_backend_free(); + return 1; + } + } + + // Prepare actual output path (may be temporary for in-place or overwrite mode) + std::string fname_out_actual = fname_out; + std::string fname_out_temp; + bool use_temp_file = allow_inplace && same_file(fname_inp, fname_out); + + if (use_temp_file || allow_overwrite) { + // Create temp file in the same directory as output for atomic rename + std::filesystem::path out_path(fname_out); + std::filesystem::path out_dir = out_path.parent_path(); + if (out_dir.empty()) { + out_dir = "."; + } + + // Generate temp filename + fname_out_temp = (out_dir / ("." + out_path.filename().string() + ".tmp.XXXXXX")).string(); + + // Create the temp file safely + // Note: mkstemp would be safer but requires char* and creates the file + // For simplicity, we'll use a simpler approach with PID + fname_out_temp = (out_dir / ("." + out_path.filename().string() + ".tmp." + std::to_string(getpid()))).string(); + fname_out_actual = fname_out_temp; + + fprintf(stderr, "%s: using temporary file: '%s'\n", __func__, fname_out_actual.c_str()); + } + fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str()); if (params.nthread > 0) { fprintf(stderr, " using %d threads", params.nthread); @@ -659,14 +763,42 @@ int main(int argc, char ** argv) { { const int64_t t_start_us = llama_time_us(); - if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ¶ms)) { + if (llama_model_quantize(fname_inp.c_str(), fname_out_actual.c_str(), ¶ms)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); + + // Clean up temp file on failure + if (!fname_out_temp.empty()) { + std::error_code ec; + std::filesystem::remove(fname_out_temp, ec); + } + + llama_backend_free(); return 1; } t_quantize_us = llama_time_us() - t_start_us; } + // If we used a temp file, atomically rename it to the final output + if (!fname_out_temp.empty()) { + fprintf(stderr, "%s: atomically moving temp file to final output\n", __func__); + std::error_code ec; + + // On POSIX systems, rename() is atomic when both paths are on the same filesystem + std::filesystem::rename(fname_out_temp, fname_out, ec); + + if (ec) { + fprintf(stderr, "%s: failed to rename temp file '%s' to '%s': %s\n", + __func__, fname_out_temp.c_str(), fname_out.c_str(), ec.message().c_str()); + + // Try to clean up temp file + std::filesystem::remove(fname_out_temp, ec); + + llama_backend_free(); + return 1; + } + } + // report timing { const int64_t t_main_end_us = llama_time_us();