From 65cdb5156e937b37eee7a30eb966746f0a01d4d9 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Tue, 13 Jun 2023 20:34:00 +0400 Subject: [PATCH 1/6] Implement on-the-fly quantization --- rwkv.cpp | 111 +++++++++++++++++++++++++++----- rwkv.h | 3 +- rwkv/rwkv_cpp_model.py | 4 +- rwkv/rwkv_cpp_shared_library.py | 7 +- tests/test_context_cloning.c | 4 +- tests/test_tiny_rwkv.c | 2 +- 6 files changed, 107 insertions(+), 24 deletions(-) diff --git a/rwkv.cpp b/rwkv.cpp index b99ac6a0..cbcea87a 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -363,27 +363,85 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = return true; } -bool rwkv_fread_ggml_tensor_data(FILE * file, const struct rwkv_tensor_header & header, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { +bool rwkv_should_be_quantized(const ggml_type source_type, const ggml_type target_type, const std::string & name, const uint32_t dim_count) { + // Quantize only 2D tensors, except embedding and head matrices. + // Embedding and head take not too much space, especially in bigger models; + // but they significantly increase perplexity when quantized. + return target_type != GGML_TYPE_COUNT && + target_type != source_type && + (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) && + dim_count == 2 && + name != "emb.weight" && + name != "head.weight"; +} + +bool rwkv_fread_ggml_tensor_data( + FILE * file, + const struct rwkv_tensor_header & header, + struct ggml_context * ctx, + std::string & name, + struct ggml_tensor *& tensor, + const ggml_type target_type +) { RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_string(file, header.key_length, name), "Failed to read tensor name"); enum ggml_type ggml_type = rwkv_type_to_ggml[header.data_type]; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); - tensor = header.dim_count == 1 - ? ggml_new_tensor_1d(ctx, ggml_type, header.width) - : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); + if (rwkv_should_be_quantized(ggml_type, target_type, name, header.dim_count)) { + // TODO Remove + fprintf(stderr, "Quantizing %s on the fly\n", name.c_str()); + + size_t buffer_size_bytes = header.dim_count == 1 + ? rwkv_tensor_size(ggml_type, header.width) + : rwkv_tensor_size(ggml_type, header.width, header.height); + + tensor = header.dim_count == 1 + ? ggml_new_tensor_1d(ctx, target_type, header.width) + : ggml_new_tensor_2d(ctx, target_type, header.width, header.height); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); + + // TODO Make safer (free on return) + char * buffer = (char *) malloc(buffer_size_bytes); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer), "Failed to read tensor data from %s", name.c_str()); + + // Quantization works only with FP32 values + if (header.data_type == TYPE_F16) { + float * float_buffer = (float *) malloc(buffer_size_bytes * 2); + + ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer, (float *) float_buffer, ggml_nelements(tensor)); + + free(buffer); + + buffer = (char *) float_buffer; + } + + int64_t histogram[16] {}; + + ggml_quantize_chunk(target_type, (const float *) buffer, tensor->data, 0, ggml_nelements(tensor), histogram); + + free(buffer); + } else { + tensor = header.dim_count == 1 + ? 
ggml_new_tensor_1d(ctx, ggml_type, header.width) + : ggml_new_tensor_2d(ctx, ggml_type, header.width, header.height); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); - ggml_set_name(tensor, name.c_str()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); + ggml_set_name(tensor, name.c_str()); + + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); + } - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, ggml_nbytes(tensor), tensor->data), "Failed to read tensor data from %s", name.c_str()); return true; } -bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor) { +bool rwkv_fread_ggml_tensor(FILE * file, struct ggml_context * ctx, std::string & name, struct ggml_tensor *& tensor, const ggml_type target_type) { struct rwkv_tensor_header header; RWKV_ENSURE_OR_FALSE_MSG(rwkv_fread_tensor_header(file, header), "Invalid tensor header"); - return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor); + return rwkv_fread_ggml_tensor_data(file, header, ctx, name, tensor, target_type); } bool rwkv_fwrite_tensor(FILE * file, const struct rwkv_tensor & tensor) { @@ -1115,7 +1173,7 @@ struct rwkv_file { } }; -bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance) { +bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & instance, const ggml_type target_type) { struct stat file_stat; struct rwkv_model model; struct rwkv_ggml_context ctx; @@ -1140,7 +1198,14 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); + if (rwkv_should_be_quantized(rwkv_type_to_ggml[tensor_header.data_type], target_type, name, tensor_header.dim_count)) { + // TODO Remove + fprintf(stderr, "Allocating less bytes for quantized tensor\n"); + + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, target_type, tensor_header.width, tensor_header.height); + } else { + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); + } if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { ffn_key_size = tensor_header.height; @@ -1156,7 +1221,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst struct ggml_tensor * tensor; while ((size_t) ftell(file.file) < (size_t) file_stat.st_size) { - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor), "Failed to read model params"); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_ggml_tensor(file.file, ctx.ctx, name, tensor, target_type), "Failed to read model params"); parameters[std::move(name)] = tensor; } } @@ -1258,12 +1323,25 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr instance(new(std::nothrow) struct rwkv_instance()); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_CTX | RWKV_ERROR_ALLOC, instance, "Failed to allocate instance"); - RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get())); + RWKV_ENSURE_OR_NULL(rwkv_instance_from_file(file_path, *instance.get(), target_type)); return 
rwkv_new_context_impl(instance, n_threads); } @@ -1438,10 +1516,10 @@ void rwkv_free(struct rwkv_context * ctx) { std::unique_ptr rwkv_ctx(ctx); } -bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * type_name) { +bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const char * format_name) { global_last_error = RWKV_ERROR_NONE; - enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(type_name)]; + enum ggml_type out_type = rwkv_type_to_ggml[rwkv_type_from_string(format_name)]; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, ggml_is_quantized(out_type), "Unsupported output data type (%s)", rwkv_type_to_string[rwkv_type_from_ggml[out_type]]); RWKV_MSG("Loading model from '%s'\n", in_path); @@ -1549,6 +1627,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const // Quantize only 2D tensors, except embedding and head matrices. // Embedding and head take not too much space, especially in bigger models; // but they significantly increase perplexity when quantized. + // TODO Use rwkv_should_be_quantized if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") { RWKV_MSG("quantizing... "); diff --git a/rwkv.h b/rwkv.h index 8327425e..dc4c5e2f 100644 --- a/rwkv.h +++ b/rwkv.h @@ -85,9 +85,10 @@ extern "C" { // Loads the model from a file and prepares it for inference. // Returns NULL on any error. + // TODO Split for compatibility and document // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); + RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const char * target_format_name); // Creates a new context from an existing one. // This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times. diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index b38c7ce2..79262ff6 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -9,12 +9,14 @@ class RWKVModel: PyTorch wrapper around rwkv.cpp model. """ + # TODO Document target format parameter def __init__( self, shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, model_path: str, thread_count: int = max(1, multiprocessing.cpu_count() // 2), gpu_layers_count: int = 0, + target_format_name: str = '' ): """ Loads the model and prepares it for inference. 
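As of this patch, on-the-fly quantization is requested from the Python wrapper through the new `target_format_name` argument (patch 2 below replaces it with an options array). A minimal usage sketch, not part of the patch, with a placeholder library path and model file name:

    import rwkv_cpp_model
    import rwkv_cpp_shared_library

    # 'librwkv.so' and the model file name are hypothetical placeholders;
    # any FP16 or FP32 model in ggml format works.
    library = rwkv_cpp_shared_library.RWKVSharedLibrary('librwkv.so')
    model = rwkv_cpp_model.RWKVModel(
        library,
        'rwkv.cpp-169M-FP16.bin',
        thread_count=4,
        target_format_name='Q5_1'  # parameters are quantized in memory while loading
    )
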
@@ -36,7 +38,7 @@ def __init__( self._library = shared_library - self._ctx = self._library.rwkv_init_from_file(model_path, thread_count) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, target_format_name) if gpu_layers_count > 0: self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count) diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index a38cbbb2..a6b950d9 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -37,7 +37,7 @@ def __init__(self, shared_library_path: str): self.library = ctypes.cdll.LoadLibrary(shared_library_path) - self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32] + self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_char_p] self.library.rwkv_init_from_file.restype = ctypes.c_void_p self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] @@ -70,7 +70,8 @@ def __init__(self, shared_library_path: str): self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p - def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVContext: + # TODO Document target format parameter + def rwkv_init_from_file(self, model_file_path: str, thread_count: int, target_format_name: str = '') -> RWKVContext: """ Loads the model from a file and prepares it for inference. Throws an exception in case of any error. Error messages would be printed to stderr. @@ -83,7 +84,7 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int) -> RWKVCo Count of threads to use, must be positive. """ - ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count)) + ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), target_format_name.encode('utf-8')) assert ptr is not None, 'rwkv_init_from_file failed, check stderr' diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index eb0f7c4c..3ff12556 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -1,11 +1,11 @@ -#include +#include "rwkv.h" #include #include #include int main() { - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2, ""); if (!ctx) { enum rwkv_error_flags error = rwkv_get_last_error(NULL); diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index adb0de77..8a687873 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -26,7 +26,7 @@ void test_model(const char * model_path, const float * expected_logits, const float max_diff) { fprintf(stderr, "Testing %s\n", model_path); - struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); + struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, ""); enum rwkv_error_flags error = rwkv_get_last_error(NULL); ASSERT(error == 0, "Unexpected error %d", error); #ifdef GGML_USE_CUBLAS From dca26e9f27a238b8750603b7ad4cd35a39258839 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 14 Jun 2023 18:46:53 +0400 Subject: [PATCH 2/6] Resolve TODO items --- rwkv.cpp | 86 +++++++++++++++++++-------------- rwkv.h | 46 +++++++++++++++++- rwkv/rwkv_cpp_model.py | 9 ++-- rwkv/rwkv_cpp_shared_library.py | 52 +++++++++++++++++--- tests/test_context_cloning.c | 2 +- tests/test_tiny_rwkv.c | 2 +- 6 files changed, 146 insertions(+), 51 deletions(-) diff 
--git a/rwkv.cpp b/rwkv.cpp index cbcea87a..b2236a08 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -363,13 +363,16 @@ bool rwkv_fread_tensor(FILE * file, struct rwkv_tensor & output, void * buffer = return true; } +// Returns true, if a tensor with specified type, name and dimension count shound be quantized to target_type. +// Returns false, if a tensor should be left as-is. bool rwkv_should_be_quantized(const ggml_type source_type, const ggml_type target_type, const std::string & name, const uint32_t dim_count) { // Quantize only 2D tensors, except embedding and head matrices. - // Embedding and head take not too much space, especially in bigger models; + // Embedding and head take little space, especially in bigger models; // but they significantly increase perplexity when quantized. - return target_type != GGML_TYPE_COUNT && + return (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) && + target_type != GGML_TYPE_COUNT && target_type != source_type && - (source_type == GGML_TYPE_F32 || source_type == GGML_TYPE_F16) && + ggml_is_quantized(target_type) && dim_count == 2 && name != "emb.weight" && name != "head.weight"; @@ -389,9 +392,6 @@ bool rwkv_fread_ggml_tensor_data( RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_UNSUPPORTED, ggml_type != GGML_TYPE_UNKNOWN, "Unsupported tensor data type %s from %s", rwkv_type_to_string[header.data_type], name.c_str()); if (rwkv_should_be_quantized(ggml_type, target_type, name, header.dim_count)) { - // TODO Remove - fprintf(stderr, "Quantizing %s on the fly\n", name.c_str()); - size_t buffer_size_bytes = header.dim_count == 1 ? rwkv_tensor_size(ggml_type, header.width) : rwkv_tensor_size(ggml_type, header.width, header.height); @@ -403,27 +403,24 @@ bool rwkv_fread_ggml_tensor_data( RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, tensor, "Failed to allocate tensor"); ggml_set_name(tensor, name.c_str()); - // TODO Make safer (free on return) - char * buffer = (char *) malloc(buffer_size_bytes); + std::unique_ptr buffer(new(std::nothrow) char[buffer_size_bytes]); - RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer), "Failed to read tensor data from %s", name.c_str()); + RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer.get()), "Failed to read tensor data from %s", name.c_str()); // Quantization works only with FP32 values if (header.data_type == TYPE_F16) { - float * float_buffer = (float *) malloc(buffer_size_bytes * 2); - - ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer, (float *) float_buffer, ggml_nelements(tensor)); + std::unique_ptr float_buffer(new(std::nothrow) char[buffer_size_bytes * 2]); - free(buffer); + ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer.get(), (float *) float_buffer.get(), ggml_nelements(tensor)); - buffer = (char *) float_buffer; + buffer.reset(float_buffer.release()); } int64_t histogram[16] {}; - ggml_quantize_chunk(target_type, (const float *) buffer, tensor->data, 0, ggml_nelements(tensor), histogram); + ggml_quantize_chunk(target_type, (const float *) buffer.get(), tensor->data, 0, ggml_nelements(tensor), histogram); - free(buffer); + buffer.reset(); } else { tensor = header.dim_count == 1 ? 
ggml_new_tensor_1d(ctx, ggml_type, header.width) @@ -1198,14 +1195,13 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst RWKV_ASSERT_NULL_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_string(file.file, tensor_header.key_length, name), "Failed to read tensor name"); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_FILE | RWKV_ERROR_FILE_READ, fseek(file.file, rwkv_tensor_size(tensor_header), SEEK_CUR) == 0, "Failed to read tensor data"); - if (rwkv_should_be_quantized(rwkv_type_to_ggml[tensor_header.data_type], target_type, name, tensor_header.dim_count)) { - // TODO Remove - fprintf(stderr, "Allocating less bytes for quantized tensor\n"); + enum ggml_type source_type = rwkv_type_to_ggml[tensor_header.data_type]; - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, target_type, tensor_header.width, tensor_header.height); - } else { - rwkv_ctx_size_add_tensor(ctx_size, 1, 0, tensor_header); - } + enum ggml_type in_memory_type = rwkv_should_be_quantized(source_type, target_type, name, tensor_header.dim_count) + ? target_type + : source_type; + + rwkv_ctx_size_add_tensor(ctx_size, 1, 0, in_memory_type, tensor_header.width, tensor_header.height); if (ffn_key_size == 0 && name == "blocks.0.ffn.key.weight") { ffn_key_size = tensor_header.height; @@ -1323,20 +1319,40 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr instance(new(std::nothrow) struct rwkv_instance()); @@ -1624,11 +1640,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); - // Quantize only 2D tensors, except embedding and head matrices. - // Embedding and head take not too much space, especially in bigger models; - // but they significantly increase perplexity when quantized. - // TODO Use rwkv_should_be_quantized - if ((header.data_type == TYPE_F32 || header.data_type == TYPE_F16) && header.dim_count == 2 && name != "emb.weight" && name != "head.weight") { + if (rwkv_should_be_quantized(rwkv_type_to_ggml[header.data_type], out_type, name, header.dim_count)) { RWKV_MSG("quantizing... "); size_t nelements = (size_t) header.width * (size_t) header.height; diff --git a/rwkv.h b/rwkv.h index dc4c5e2f..3ea021e4 100644 --- a/rwkv.h +++ b/rwkv.h @@ -83,12 +83,54 @@ extern "C" { // - ctx: the context the retrieve the error for, or NULL for the global error. RWKV_API enum rwkv_error_flags rwkv_get_last_error(struct rwkv_context * ctx); + enum rwkv_init_from_file_option_key { + // Sets target format of model parameters. + // + // If an FP16 or FP32 model is being loaded, and this option is set, + // parameters will be quantized just-in-time into the specified format. + // If an already quantized model is being loaded, value of this option is ignored. + // The function will not read the whole model file at once, but will do quantization tensor-by-tensor; + // it is safe to load big models which will fit into RAM when quantized. + // Use of this option will introduce significant one-time delay when loading the model. + // + // Intended use-case is to have only FP16 model on disk, while not wasting + // the disk space on models of all available quantized formats. + // + // Allowed values: + // - Q4_0 + // - Q4_1 + // - Q5_0 + // - Q5_1 + // - Q8_0 + RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, + // Do not use this as an actual option key. 
+ RWKV_INIT_FROM_FILE_OPTION_COUNT + }; + + struct rwkv_init_from_file_option { + // Key of the option. + enum rwkv_init_from_file_option_key key; + // Value of the option as a NULL-terminated, UTF-8 encoded string. + char * value; + }; + // Loads the model from a file and prepares it for inference. + // Loading behavior can be customized with options, but none of them are required. + // Function behavior when multiple options with the same key are specified is undefined. // Returns NULL on any error. - // TODO Split for compatibility and document // - model_file_path: path to model file in ggml format. // - n_threads: count of threads to use, must be positive. - RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads, const char * target_format_name); + // - options: array of options. Passing NULL is the same as setting option_count to 0. + // - option_count: size of the options array. + RWKV_API struct rwkv_context * rwkv_init_from_file_ex( + const char * model_file_path, + const uint32_t n_threads, + const struct rwkv_init_from_file_option * options, + const size_t option_count + ); + + // Same as rwkv_init_from_file_ex, but passing an empty array of options. + RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, const uint32_t n_threads); // Creates a new context from an existing one. // This can allow you to run multiple rwkv_eval's in parallel, without having to load a single model multiple times. diff --git a/rwkv/rwkv_cpp_model.py b/rwkv/rwkv_cpp_model.py index 79262ff6..fd775731 100644 --- a/rwkv/rwkv_cpp_model.py +++ b/rwkv/rwkv_cpp_model.py @@ -2,21 +2,20 @@ import torch import multiprocessing import rwkv_cpp_shared_library -from typing import Tuple, Optional +from typing import Dict, Tuple, Optional class RWKVModel: """ PyTorch wrapper around rwkv.cpp model. """ - # TODO Document target format parameter def __init__( self, shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary, model_path: str, thread_count: int = max(1, multiprocessing.cpu_count() // 2), gpu_layers_count: int = 0, - target_format_name: str = '' + options: Optional[Dict[rwkv_cpp_shared_library.RWKVInitFromFileOptionKey, str]] = None ): """ Loads the model and prepares it for inference. @@ -30,6 +29,8 @@ def __init__( Path to RWKV model file in ggml format. thread_count : int Thread count to use. If not set, defaults to CPU count / 2. + options : Optional[Dict[RWKVInitFromFileOptionKey, str]] + Options passed to rwkv_init_from_file_ex. """ assert os.path.isfile(model_path), f'{model_path} is not a file' @@ -38,7 +39,7 @@ def __init__( self._library = shared_library - self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, target_format_name) + self._ctx = self._library.rwkv_init_from_file(model_path, thread_count, options) if gpu_layers_count > 0: self._library.rwkv_gpu_offload_layers(self._ctx, gpu_layers_count) diff --git a/rwkv/rwkv_cpp_shared_library.py b/rwkv/rwkv_cpp_shared_library.py index a6b950d9..562bf7a1 100644 --- a/rwkv/rwkv_cpp_shared_library.py +++ b/rwkv/rwkv_cpp_shared_library.py @@ -2,7 +2,8 @@ import sys import ctypes import pathlib -from typing import Optional +import enum +from typing import Dict, Optional QUANTIZED_FORMAT_NAMES = ( 'Q4_0', @@ -14,6 +15,29 @@ P_FLOAT = ctypes.POINTER(ctypes.c_float) +class RWKVInitFromFileOptionKey(enum.Enum): + # Sets target format of model parameters. 
+ # + # If an FP16 or FP32 model is being loaded, and this option is set, + # parameters will be quantized just-in-time into the specified format. + # If an already quantized model is being loaded, value of this option is ignored. + # The function will not read the whole model file at once, but will do quantization tensor-by-tensor; + # it is safe to load big models which will fit into RAM when quantized. + # Use of this option will introduce significant one-time delay when loading the model. + # + # Intended use-case is to have only FP16 model on disk, while not wasting + # the disk space on models of all available quantized formats. + # + # For allowed values, see QUANTIZED_FORMAT_NAMES. + TARGET_FORMAT_NAME = 0 + +class RWKVInitFromFileOption(ctypes.Structure): + + _fields_ = [ + ('key', ctypes.c_int), + ('value', ctypes.c_char_p) + ] + class RWKVContext: def __init__(self, ptr: ctypes.pointer): @@ -37,8 +61,8 @@ def __init__(self, shared_library_path: str): self.library = ctypes.cdll.LoadLibrary(shared_library_path) - self.library.rwkv_init_from_file.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.c_char_p] - self.library.rwkv_init_from_file.restype = ctypes.c_void_p + self.library.rwkv_init_from_file_ex.argtypes = [ctypes.c_char_p, ctypes.c_uint32, ctypes.POINTER(RWKVInitFromFileOption), ctypes.c_size_t] + self.library.rwkv_init_from_file_ex.restype = ctypes.c_void_p self.library.rwkv_gpu_offload_layers.argtypes = [ctypes.c_void_p, ctypes.c_uint32] self.library.rwkv_gpu_offload_layers.restype = ctypes.c_bool @@ -70,10 +94,10 @@ def __init__(self, shared_library_path: str): self.library.rwkv_get_system_info_string.argtypes = [] self.library.rwkv_get_system_info_string.restype = ctypes.c_char_p - # TODO Document target format parameter - def rwkv_init_from_file(self, model_file_path: str, thread_count: int, target_format_name: str = '') -> RWKVContext: + def rwkv_init_from_file(self, model_file_path: str, thread_count: int, options: Optional[Dict[RWKVInitFromFileOptionKey, str]] = None) -> RWKVContext: """ Loads the model from a file and prepares it for inference. + Loading behavior can be customized with options, but none of them are required. Throws an exception in case of any error. Error messages would be printed to stderr. Parameters @@ -82,9 +106,25 @@ def rwkv_init_from_file(self, model_file_path: str, thread_count: int, target_fo Path to model file in ggml format. thread_count : int Count of threads to use, must be positive. + options : Optional[Dict[RWKVInitFromFileOptionKey, str]] + Options passed to rwkv_init_from_file_ex. 
""" - ptr = self.library.rwkv_init_from_file(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), target_format_name.encode('utf-8')) + options_count = 0 + options_ptr = None + + if options is not None and len(options) > 0: + options_count = len(options) + options_ptr = (RWKVInitFromFileOption * options_count)() + + i = 0 + for k, v in options.items(): + options_ptr[i].key = k.value + options_ptr[i].value = v.encode('utf-8') + + i += 1 + + ptr = self.library.rwkv_init_from_file_ex(model_file_path.encode('utf-8'), ctypes.c_uint32(thread_count), options_ptr, options_count) assert ptr is not None, 'rwkv_init_from_file failed, check stderr' diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index 3ff12556..7346c45b 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -5,7 +5,7 @@ #include int main() { - struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2, ""); + struct rwkv_context * ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32.bin", 2); if (!ctx) { enum rwkv_error_flags error = rwkv_get_last_error(NULL); diff --git a/tests/test_tiny_rwkv.c b/tests/test_tiny_rwkv.c index 8a687873..adb0de77 100644 --- a/tests/test_tiny_rwkv.c +++ b/tests/test_tiny_rwkv.c @@ -26,7 +26,7 @@ void test_model(const char * model_path, const float * expected_logits, const float max_diff) { fprintf(stderr, "Testing %s\n", model_path); - struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS, ""); + struct rwkv_context * model = rwkv_init_from_file(model_path, N_THREADS); enum rwkv_error_flags error = rwkv_get_last_error(NULL); ASSERT(error == 0, "Unexpected error %d", error); #ifdef GGML_USE_CUBLAS From 7a13fd26673d845ec6f117aedfec6423b8c2e041 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 14 Jun 2023 19:17:40 +0400 Subject: [PATCH 3/6] Fix error code --- rwkv.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rwkv.cpp b/rwkv.cpp index b2236a08..cfe7baf5 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -1349,7 +1349,7 @@ struct rwkv_context * rwkv_init_from_file_ex( break; default: - RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ARGS | RWKV_ERROR_DATA_TYPE, false, "Invalid option key %d", options[i].key); + RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ARGS, false, "Invalid option key %d", options[i].key); break; } From 4d27fa8bfa71a393e74df01a9f92b48f3d7512c4 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 14 Jun 2023 19:18:30 +0400 Subject: [PATCH 4/6] Reformat code --- rwkv.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rwkv.cpp b/rwkv.cpp index cfe7baf5..1bf3efcb 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -484,7 +484,7 @@ struct rwkv_model { struct ggml_tensor * ln0_weight; struct ggml_tensor * ln0_bias; - std::unique_ptr layers; + std::unique_ptr layers; struct ggml_tensor * ln_out_weight; struct ggml_tensor * ln_out_bias; @@ -635,9 +635,9 @@ struct rwkv_context { // Reused by all graphs. 
struct rwkv_ggml_context ctx; struct ggml_tensor * input_state; - std::unique_ptr input_layers; + std::unique_ptr input_layers; struct ggml_tensor * output_state; - std::unique_ptr output_layers; + std::unique_ptr output_layers; struct ggml_tensor * logits; uint32_t n_threads; @@ -665,7 +665,7 @@ bool rwkv_set_params(struct rwkv_model & model, F callback) { RWKV_ENSURE_OR_FALSE(callback("blocks.0.ln0.bias", model.ln0_bias)); uint32_t n_layer = model.header.n_layer; - std::unique_ptr layers(new(std::nothrow) struct rwkv_layer [n_layer]); + std::unique_ptr layers(new(std::nothrow) struct rwkv_layer[n_layer]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, layers.get(), "Failed to allocate model layers"); model.layers = std::move(layers); @@ -1223,7 +1223,7 @@ bool rwkv_instance_from_file(const char * file_path, struct rwkv_instance & inst } std::unordered_map & parameters_ref = parameters; - RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model, [&](const char * key, struct ggml_tensor *& dest) { + RWKV_ASSERT_NULL(RWKV_ERROR_MODEL_PARAMS | RWKV_ERROR_PARAM_MISSING, rwkv_set_params(model,[&](const char * key, struct ggml_tensor *& dest) { struct ggml_tensor * tensor = parameters_ref[key]; RWKV_ENSURE_OR_FALSE_MSG(tensor, "Model parameter %s not found", key); dest = tensor; @@ -1264,11 +1264,11 @@ struct rwkv_context * rwkv_new_context_impl(std::shared_ptr inputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr inputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, inputs.get(), "Failed to allocate input state parts"); // We collect parts of output state here. Each part is (n_embed) vector. - std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state [n_layer]); + std::unique_ptr outputs(new(std::nothrow) struct rwkv_layer_state[n_layer]); RWKV_ASSERT_NULL_MSG(RWKV_ERROR_ALLOC, outputs.get(), "Failed to allocate output state parts"); for (size_t i = 0; i < n_layer; i++) { @@ -1618,7 +1618,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const // This is a histogram of quantized values. If it shows single 1.0, then all 0.0, something went very wrong! 
int64_t hist_all[16] {}; - std::unique_ptr scratch(new(std::nothrow) uint8_t [max_in_size + max_out_size]); + std::unique_ptr scratch(new(std::nothrow) uint8_t[max_in_size + max_out_size]); RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_ALLOC, scratch.get(), "Failed to allocate buffer"); uint8_t * in_buf = scratch.get(); From d3b6749c2d3658c3d6e36879cb5decdb78c305d1 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 14 Jun 2023 19:25:45 +0400 Subject: [PATCH 5/6] Consistently use FP16 and FP32 for rwkv.cpp data types --- README.md | 4 ++-- rwkv.cpp | 24 ++++++++++++------------ rwkv/convert_pytorch_to_ggml.py | 12 +++++++----- rwkv/convert_pytorch_to_ggml.test.py | 2 +- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index ef84c199..7b41a621 100644 --- a/README.md +++ b/README.md @@ -130,10 +130,10 @@ This option would require a little more manual work, but you can use it with any ```commandline # Windows -python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float16 +python rwkv\convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 # Linux / MacOS -python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin float16 +python rwkv/convert_pytorch_to_ggml.py ~/Downloads/RWKV-4-Pile-169M-20220807-8023.pth ~/Downloads/rwkv.cpp-169M.bin FP16 ``` **Optionally**, quantize the model into one of quantized formats from the table above: diff --git a/rwkv.cpp b/rwkv.cpp index 1bf3efcb..c90f07a6 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -174,8 +174,8 @@ bool rwkv_fwrite_data(FILE * file, const void * data, const size_t length) { #define TYPE_UNKNOWN TYPE_COUNT enum rwkv_type { - TYPE_F32, - TYPE_F16, + TYPE_FP32, + TYPE_FP16, TYPE_Q4_0, TYPE_Q4_1, TYPE_Q4_1_O, // Unsupported @@ -204,8 +204,8 @@ extern const enum ggml_type rwkv_type_to_ggml[TYPE_COUNT + 1] = { }; extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { - TYPE_F32, /* F32 */ - TYPE_F16, /* F16 */ + TYPE_FP32, /* FP32 */ + TYPE_FP16, /* FP16 */ TYPE_Q4_0, /* Q4_0 */ TYPE_Q4_1, /* Q4_1 */ TYPE_Q4_2, /* Q4_2 */ @@ -220,7 +220,7 @@ extern const enum rwkv_type rwkv_type_from_ggml[GGML_TYPE_COUNT + 1] = { TYPE_COUNT, /* COUNT */ }; -extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"float32", "float16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; +extern const char * rwkv_type_to_string[TYPE_COUNT + 1] = {"FP32", "FP16", "Q4_0", "Q4_1", "Q4_1_O", "Q4_2", "Q4_3", "Q5_0", "Q5_1", "Q8_0", "unknown"}; enum rwkv_type rwkv_type_from_string(const char * str) { for (int ord = 0; ord < TYPE_COUNT; ord++) { @@ -408,7 +408,7 @@ bool rwkv_fread_ggml_tensor_data( RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_FILE_READ, rwkv_fread_data(file, buffer_size_bytes, buffer.get()), "Failed to read tensor data from %s", name.c_str()); // Quantization works only with FP32 values - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { std::unique_ptr float_buffer(new(std::nothrow) char[buffer_size_bytes * 2]); ggml_fp16_to_fp32_row((const ggml_fp16_t *) buffer.get(), (float *) float_buffer.get(), ggml_nelements(tensor)); @@ -1092,7 +1092,7 @@ bool rwkv_build_sequence_graph( struct ggml_tensor * x = ggml_get_rows(ctx, model.emb, tokens); x = rwkv_layer_norm(ctx, x, ggml_repeat(ctx, model.ln0_weight, x), ggml_repeat(ctx, model.ln0_bias, x)); - + for (size_t i = 0; i < model.header.n_layer; i++) { struct rwkv_layer & layer = model.layers[i]; 
struct rwkv_layer_state state = inputs[i]; @@ -1558,7 +1558,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const RWKV_ASSERT_FALSE_MSG( RWKV_ERROR_FILE, in_type == GGML_TYPE_F32 || in_type == GGML_TYPE_F16, - "Unsupported input data type (%s); needs to be F32 or F16", + "Unsupported input data type (%s); needs to be FP32 or FP16", rwkv_type_to_string[rwkv_type_from_ggml[in_type]] ); @@ -1571,7 +1571,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t orig_total_size = 0; size_t new_total_size = 0; - // Required to init the fp16 tables + // Required to init the F16 tables // Doesn't crash if ggml_init fails ggml_free(ggml_init({ 0, NULL, true })); @@ -1590,7 +1590,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const } // f16 type tensors get relocated to out and then converted into f32 at in - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { if (in_size > max_out_size) { max_out_size = in_size; } @@ -1636,7 +1636,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const const char * name_str = name.c_str(); RWKV_MSG("%*s - [%5" PRId32 ", %5" PRId32 "], type = %6s ", (int) max_key_length, name_str, header.width, header.height, rwkv_type_to_string[header.data_type]); - data = header.data_type == TYPE_F16 ? out_buf : in_buf; + data = header.data_type == TYPE_FP16 ? out_buf : in_buf; size_t orig_size = rwkv_tensor_size(header), new_size = orig_size; RWKV_ASSERT_FALSE_MSG(RWKV_ERROR_MODEL_PARAMS, rwkv_fread_data(in_file.file, orig_size, data), "\nFailed to read tensor data of %s", name_str); @@ -1645,7 +1645,7 @@ bool rwkv_quantize_model_file(const char * in_path, const char * out_path, const size_t nelements = (size_t) header.width * (size_t) header.height; - if (header.data_type == TYPE_F16) { + if (header.data_type == TYPE_FP16) { ggml_fp16_to_fp32_row((const ggml_fp16_t *) out_buf, (float *) in_buf, nelements); } diff --git a/rwkv/convert_pytorch_to_ggml.py b/rwkv/convert_pytorch_to_ggml.py index 2ea4a48d..18debea4 100644 --- a/rwkv/convert_pytorch_to_ggml.py +++ b/rwkv/convert_pytorch_to_ggml.py @@ -1,5 +1,5 @@ # Converts an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file. -# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin float32 +# Usage: python convert_pytorch_to_ggml.py C:\RWKV-4-Pile-169M-20220807-8023.pth C:\rwkv.cpp-169M.bin FP16 # Get model checkpoints from https://huggingface.co/BlinkDL # See FILE_FORMAT.md for the documentation on the file format. 
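The hunks below rename the accepted data_type values from float16/float32 to FP16/FP32. For callers that drive the converter from Python instead of the command line, a minimal sketch, not part of the patch, with placeholder file names:

    import torch
    import convert_pytorch_to_ggml

    # Checkpoint and output paths are hypothetical placeholders.
    state_dict = torch.load('RWKV-4-Pile-169M-20220807-8023.pth', map_location='cpu')
    convert_pytorch_to_ggml.write_state_dict(
        state_dict,
        dest_path='rwkv.cpp-169M-FP16.bin',
        data_type='FP16'
    )
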
@@ -12,7 +12,7 @@ def parse_args(): parser = argparse.ArgumentParser(description='Convert an RWKV model checkpoint in PyTorch format to an rwkv.cpp compatible file') parser.add_argument('src_path', help='Path to PyTorch checkpoint file') parser.add_argument('dest_path', help='Path to rwkv.cpp checkpoint file, will be overwritten') - parser.add_argument('data_type', help='Data type, float16 or float32', type=str, choices=['float16', 'float32'], default='float32') + parser.add_argument('data_type', help='Data type, FP16 or FP32', type=str, choices=['FP16', 'FP32'], default='FP16') return parser.parse_args() def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: @@ -26,6 +26,8 @@ def get_layer_count(state_dict: Dict[str, torch.Tensor]) -> int: return n_layer def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_type: str) -> None: + is_FP16 = data_type == 'FP16' or data_type == 'float16' + emb_weight: torch.Tensor = state_dict['emb.weight'] n_layer = get_layer_count(state_dict) @@ -42,7 +44,7 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t n_vocab, n_embed, n_layer, - 1 if data_type == 'float16' else 0 + 1 if is_FP16 else 0 )) for k in state_dict.keys(): @@ -56,8 +58,8 @@ def write_state_dict(state_dict: Dict[str, torch.Tensor], dest_path: str, data_t if '.time_decay' in k: tensor = -torch.exp(tensor) - # Keep 1-dim vectors in fp32 - if data_type == 'float16' and len(tensor.shape) > 1: + # Keep 1-dim vectors in FP32 + if is_FP16 and len(tensor.shape) > 1: tensor = tensor.half() shape = tensor.shape diff --git a/rwkv/convert_pytorch_to_ggml.test.py b/rwkv/convert_pytorch_to_ggml.test.py index 9ced1d05..501a85ef 100644 --- a/rwkv/convert_pytorch_to_ggml.test.py +++ b/rwkv/convert_pytorch_to_ggml.test.py @@ -13,7 +13,7 @@ def test() -> None: 'blocks.0.ln1.weight': torch.tensor([1], dtype=torch.float32) } - convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='float32') + convert_pytorch_to_ggml.write_state_dict(state_dict, dest_path=test_file_path, data_type='FP32') with open(test_file_path, 'rb') as input: actual_bytes: bytes = input.read() From c49d3d895f43e682e6beb4fa5f52babcb380dbd9 Mon Sep 17 00:00:00 2001 From: saharNooby Date: Wed, 14 Jun 2023 19:36:33 +0400 Subject: [PATCH 6/6] Add test for on-the-fly quantization --- tests/CMakeLists.txt | 7 ++- tests/test_context_cloning.c | 2 + tests/test_quantization_on_the_fly.c | 91 ++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 tests/test_quantization_on_the_fly.c diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d176f7bd..4090c9b7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,6 +12,7 @@ file(COPY tiny-rwkv-660K-FP32.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY tiny-rwkv-660K-FP16.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) file(COPY expected_logits.bin DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) -rwkv_add_test(test_ggml_basics.c) -rwkv_add_test(test_tiny_rwkv.c) -rwkv_add_test(test_context_cloning.c) +file(GLOB tests *.c) +foreach (test ${tests}) + rwkv_add_test(${test}) +endforeach() diff --git a/tests/test_context_cloning.c b/tests/test_context_cloning.c index 7346c45b..6632475b 100644 --- a/tests/test_context_cloning.c +++ b/tests/test_context_cloning.c @@ -1,3 +1,5 @@ +// Tests that after context cloning evaluation gives identical results. 
+ #include "rwkv.h" #include diff --git a/tests/test_quantization_on_the_fly.c b/tests/test_quantization_on_the_fly.c new file mode 100644 index 00000000..d63ba8b0 --- /dev/null +++ b/tests/test_quantization_on_the_fly.c @@ -0,0 +1,91 @@ +// Tests that results from on-the-fly quantized model are identical with results of pre-quantized model. + +#include "ggml.h" +#include "rwkv.h" + +#include +#include +#include + +#define N_THREADS 2 + +int main(void) { + rwkv_quantize_model_file("tiny-rwkv-660K-FP32.bin", "tiny-rwkv-660K-FP32-Q5_1.bin", "Q5_1"); + + struct rwkv_context * prequantized_ctx = rwkv_init_from_file("tiny-rwkv-660K-FP32-Q5_1.bin", N_THREADS); + + if (!prequantized_ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + // --- + + struct rwkv_init_from_file_option option = {RWKV_INIT_FROM_FILE_OPTION_TARGET_FORMAT_NAME, "Q5_1"}; + + struct rwkv_context * on_the_fly_quantized_ctx = rwkv_init_from_file_ex("tiny-rwkv-660K-FP32.bin", N_THREADS, &option, 1); + + if (!on_the_fly_quantized_ctx) { + enum rwkv_error_flags error = rwkv_get_last_error(NULL); + fprintf(stderr, "Unexpected error 0x%.8X\n", error); + return EXIT_FAILURE; + } + + // --- + + float * state = calloc(rwkv_get_state_len(prequantized_ctx), sizeof(float)); + + if (!state) { + fprintf(stderr, "Failed to allocate state\n"); + return EXIT_FAILURE; + } + + float * expected_logits = calloc(rwkv_get_logits_len(prequantized_ctx), sizeof(float)); + + if (!expected_logits) { + fprintf(stderr, "Failed to allocate logits\n"); + return EXIT_FAILURE; + } + + const unsigned char prompt[12] = "hello world"; + + rwkv_eval(prequantized_ctx, prompt[0], NULL, state, expected_logits); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(prequantized_ctx, prompt[i], state, state, expected_logits); + } + + // --- + + float * actual_logits = calloc(rwkv_get_logits_len(on_the_fly_quantized_ctx), sizeof(float)); + + if (!actual_logits) { + fprintf(stderr, "Failed to allocate logits\n"); + return EXIT_FAILURE; + } + + rwkv_eval(on_the_fly_quantized_ctx, prompt[0], NULL, state, actual_logits); + + for (int i = 1; prompt[i] != 0; i++) { + rwkv_eval(on_the_fly_quantized_ctx, prompt[i], state, state, actual_logits); + } + + // --- + + if (memcmp(expected_logits, actual_logits, rwkv_get_logits_len(on_the_fly_quantized_ctx) * sizeof(float))) { + fprintf(stderr, "Results not identical :(\n"); + return EXIT_FAILURE; + } else { + fprintf(stdout, "Results identical, success!\n"); + } + + rwkv_free(on_the_fly_quantized_ctx); + rwkv_free(prequantized_ctx); + + free(expected_logits); + free(actual_logits); + free(state); + + return 0; +}
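With the final API in this series, the same on-the-fly quantization can be requested from the Python wrapper through the options dictionary added in patch 2. A minimal sketch, not part of the patch, with placeholder file names:

    import rwkv_cpp_model
    import rwkv_cpp_shared_library

    # 'librwkv.so' and the model file name are hypothetical placeholders.
    # The file on disk stays FP16; parameters are quantized to Q5_1
    # tensor-by-tensor while the model is being loaded.
    library = rwkv_cpp_shared_library.RWKVSharedLibrary('librwkv.so')
    model = rwkv_cpp_model.RWKVModel(
        library,
        'rwkv.cpp-169M-FP16.bin',
        thread_count=4,
        options={
            rwkv_cpp_shared_library.RWKVInitFromFileOptionKey.TARGET_FORMAT_NAME: 'Q5_1'
        }
    )

Keeping a single FP16 file on disk and choosing the quantized format at load time is the intended use-case described in rwkv.h, at the cost of a one-time quantization delay on startup.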