
Commit 52bead0

Merge branch 'ggml-org:master' into master
2 parents d0b0424 + 10a0351 commit 52bead0

17 files changed: +1145 / -77 lines


convert_hf_to_gguf.py

Lines changed: 169 additions & 0 deletions
@@ -3508,6 +3508,175 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]
 
 
+@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM")
+class Plamo2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def set_vocab(self):
+        # PLaMo 2 uses a custom tokenizer with a .jsonl file
+        # We need to handle this specially
+        tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+        tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+        if not tokenizer_jsonl_path.is_file():
+            raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
+
+        # Load tokenizer config
+        with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
+            tokenizer_config = json.load(f)
+
+        # Load tokens from JSONL file (actually a list format)
+        tokens = []
+        scores = []
+        toktypes = []
+
+        with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f):
+                if line.strip():
+                    token_data = json.loads(line)
+                    # Format: [token, score, type, ?, ?, ?, ?]
+                    token = token_data[0].encode("utf-8")
+                    score = float(token_data[1])
+                    token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+                    tokens.append(token)
+                    scores.append(score)
+
+                    # Map token type strings to GGUF token types
+                    if token_type_str == "UNKNOWN":
+                        toktypes.append(gguf.TokenType.UNKNOWN)
+                    elif token_type_str == "CONTROL":
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    elif token_type_str == "BYTE":
+                        toktypes.append(gguf.TokenType.BYTE)
+                    else:
+                        # Check for PLaMo-2 special tokens
+                        token_str = token_data[0]
+                        if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
+                            toktypes.append(gguf.TokenType.CONTROL)
+                        else:
+                            toktypes.append(gguf.TokenType.NORMAL)
+
+        vocab_size = self.hparams["vocab_size"]
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(gguf.TokenType.UNUSED)
+
+        # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
+        self.gguf_writer.add_tokenizer_model("plamo2")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        # Add special tokens from config
+        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
+            self.gguf_writer.add_bos_token_id(token_id)
+        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
+            token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
+            self.gguf_writer.add_eos_token_id(token_id)
+        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
+            token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
+            self.gguf_writer.add_pad_token_id(token_id)
+        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
+            token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
+            self.gguf_writer.add_sep_token_id(token_id)
+        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
+            token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
+            self.gguf_writer.add_unk_token_id(token_id)
+
+        # Add <|plamo:op|> as EOT to ensure appropriate end of generation
+        self.gguf_writer.add_eot_token_id(4)
+
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        # Which layers are Mamba layers
+        # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer)
+        # This logic matches modeling_plamo.py's is_mamba function
+        mamba_step = hparams.get("mamba_step", 2)
+        mamba_enabled = hparams.get("mamba_enabled", True)
+        mamba_layers = []
+
+        if mamba_enabled:
+            for i in range(block_count):
+                if block_count <= (mamba_step // 2):
+                    # use attention in last layer
+                    is_mamba = (i != block_count - 1)
+                else:
+                    is_mamba = (i % mamba_step) != (mamba_step // 2)
+                if is_mamba:
+                    mamba_layers.append(0)
+                else:
+                    mamba_layers.append(hparams.get("num_key_value_heads", 4))
+
+        if mamba_layers:
+            self.gguf_writer.add_head_count_kv(mamba_layers)
+
+        self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
+        self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
+        self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
+        self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
+
+        # Mamba parameters
+        self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
+        self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4))
+        self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64))
+        intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128)
+        self.gguf_writer.add_ssm_inner_size(intermediate_size)
+        self.gguf_writer.add_ssm_group_count(0)
+
+        # MLP feed forward parameters (for attention layers)
+        self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".A_log"):
+            data_torch = -torch.exp(data_torch)
+        elif name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+        elif name.endswith(".dt_norm_weight"):
+            name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight"
+        elif name.endswith(".B_norm_weight"):
+            name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight"
+        elif name.endswith(".C_norm_weight"):
+            name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight"
+        elif name.endswith(".k_weight"):
+            name = name.rpartition(".k_weight")[0] + ".k.weight"
+        elif name.endswith(".q_weight"):
+            name = name.rpartition(".q_weight")[0] + ".q.weight"
+        elif name.endswith(".conv1d.weight"):
+            data_torch = torch.squeeze(data_torch)  # remove (, 1, )
+            assert data_torch.ndim == 2
+        elif name.endswith(".pre_mixer_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mixer_norm.weight"):
+            data_torch += 1.0 / 5
+        elif name.endswith(".pre_mlp_norm.weight"):
+            data_torch += 1.0
+        elif name.endswith(".post_mlp_norm.weight"):
+            data_torch += 1.0 / (5**1.5)
+        elif name.endswith(".norm.weight"):
+            data_torch += 1.0
+
+        new_name = self.map_tensor_name(name)
+
+        return [(new_name, data_torch)]
+
+
 @ModelBase.register("CodeShellForCausalLM")
 class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
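
(Side note, not part of the diff: the mamba_step logic in set_gguf_parameters decides, per layer, whether the layer is a Mamba layer, stored as 0 in head_count_kv, or an attention layer, stored as the KV-head count. A minimal standalone sketch of that pattern, with a hypothetical helper name and the same default hparams values as above:)

# Illustrative only: reproduces the layer-pattern logic from set_gguf_parameters above.
def plamo2_kv_head_pattern(block_count: int, mamba_step: int = 2, num_key_value_heads: int = 4) -> list[int]:
    pattern = []
    for i in range(block_count):
        if block_count <= (mamba_step // 2):
            # tiny models: keep attention only in the last layer
            is_mamba = (i != block_count - 1)
        else:
            is_mamba = (i % mamba_step) != (mamba_step // 2)
        pattern.append(0 if is_mamba else num_key_value_heads)
    return pattern

# With mamba_step=2 the layers alternate, e.g. for a hypothetical 8-layer model:
print(plamo2_kv_head_pattern(8))  # [0, 4, 0, 4, 0, 4, 0, 4]  (0 marks a Mamba layer)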

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 7 additions & 1 deletion
@@ -3,7 +3,10 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 template<typename src_t, typename dst_t>
-__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}
+__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
+    GGML_UNUSED(src_f);
+    GGML_UNUSED(dst_f);
+}
 
 template<>
 __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
@@ -53,6 +56,9 @@ static __global__ void k_set_rows(
     const src_t* src_elem = src0_row + i00;
     dst_t* dst_elem = dst_row_ptr + i00;
     set_rows_1(src_elem, dst_elem);
+
+    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne13);
 }
 
 template<typename src_t, typename dst_t>

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 9 additions & 3 deletions
@@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return s;
     };
 
+    bool rte = device->float_controls_rte_fp16;
 #define CREATE_BINARY(name, namemod, spec) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
-            #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+            #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
             "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
 
     CREATE_BINARY(add, , {0})
@@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #undef CREATE_UNARY
 
 #define CREATE_GLU(name) \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
+    if (device->float_controls_rte_fp16) { \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+    } else { \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+    }
 
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
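
(Side note, not part of the diff: both macros now branch on device->float_controls_rte_fp16 and pick the *_rte SPIR-V variants, which are built with RTE16 set so 16-bit float results use round-to-nearest-even, as the comment in rte.comp indicates. A trivial sketch of the resulting name selection, in Python, with a hypothetical helper name:)

# Illustrative only (hypothetical helper, not llama.cpp API): which GLU shader
# variant CREATE_GLU ends up loading for a given precision and device capability.
def pick_glu_variant(name: str, fp16: bool, device_rte_fp16: bool) -> str:
    return name + ("_f16" if fp16 else "_f32") + ("_rte" if device_rte_fp16 else "")

print(pick_glu_variant("swiglu", fp16=True, device_rte_fp16=True))   # swiglu_f16_rte
print(pick_glu_variant("geglu", fp16=False, device_rte_fp16=False))  # geglu_f32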

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 5 deletions
@@ -1,10 +1,6 @@
 #version 450
 
-#if RTE16
-#extension GL_EXT_spirv_intrinsics : enable
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif // RTE16
-
+#include "rte.comp"
 #include "types.comp"
 
 #if defined(SET_ROWS) && QUANT_K == 1

ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,8 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_control_flow_attributes : require
 
+#include "rte.comp"
+
 layout (push_constant) uniform parameter
 {
     uint ne;

ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 #extension GL_EXT_shader_16bit_storage : require
 
+#include "rte.comp"
+
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 1 addition & 4 deletions
@@ -1,12 +1,9 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_spirv_intrinsics: enable
 #extension GL_EXT_control_flow_attributes : require
 
-#if RTE16
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif
+#include "rte.comp"
 
 layout (push_constant) uniform parameter
 {

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp

Lines changed: 1 addition & 4 deletions
@@ -1,11 +1,8 @@
 #include "types.comp"
 
 #extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_spirv_intrinsics: enable
 
-#if RTE16
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif
+#include "rte.comp"
 
 layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
 
ggml/src/ggml-vulkan/vulkan-shaders/rte.comp

Lines changed: 5 additions & 0 deletions (new file)
@@ -0,0 +1,5 @@
+
+#if RTE16
+#extension GL_EXT_spirv_intrinsics : enable
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif // RTE16
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 69 additions & 16 deletions
@@ -537,8 +537,10 @@ void process_shaders() {
         for (auto src0_f16 : {false, true}) {
             for (auto src1_f16 : {false, true}) {
                 for (auto dst_f16 : {false, true}) {
-                    auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
-                    string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+                    for (auto rte : {false, true}) {
+                        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
+                        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+                    }
                 }
             }
         }
@@ -592,16 +594,19 @@ void process_shaders() {
     string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
-    string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    for (auto rte : {false, true}) {
+        std::string suffix = rte ? "_rte" : "";
+        string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+    }
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -709,11 +714,59 @@ void write_output_files() {
             std::remove(path.c_str());
         }
     }
+
+    std::string suffixes[2] = {"_f32", "_f16"};
     for (const char *op : {"add", "sub", "mul", "div"}) {
-        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
-        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
-        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
-        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
+        std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
+        std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
+        for (uint32_t t0 = 0; t0 < 2; ++t0) {
+            if (t0 == 0) {
+                data += "{";
+                len += "{";
+            }
+            for (uint32_t t1 = 0; t1 < 2; ++t1) {
+                if (t1 == 0) {
+                    data += "{";
+                    len += "{";
+                }
+                for (uint32_t t2 = 0; t2 < 2; ++t2) {
+                    if (t2 == 0) {
+                        data += "{";
+                        len += "{";
+                    }
+                    for (uint32_t rte = 0; rte < 2; ++rte) {
+                        if (rte == 0) {
+                            data += "{";
+                            len += "{";
+                        }
+                        data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+                        len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+                        data += "_data,";
+                        len += "_len,";
+                        if (rte == 1) {
+                            data += "}, ";
+                            len += "}, ";
+                        }
+                    }
+                    if (t2 == 1) {
+                        data += "}, ";
+                        len += "}, ";
+                    }
+                }
+                if (t1 == 1) {
+                    data += "}, ";
+                    len += "}, ";
+                }
+            }
+            if (t0 == 1) {
+                data += "};\n";
+                len += "};\n";
+            }
+        }
+        fprintf(src, data.c_str());
+        fprintf(src, len.c_str());
     }
     fclose(hdr);
     fclose(src);
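
(Side note, not part of the diff: for each binary op the generator now emits a four-dimensional table indexed as name_data[src0 is f16][src1 is f16][dst is f16][rte], matching the [s0][s1][d][rte] indexing added to CREATE_BINARY in ggml-vulkan.cpp. A small illustrative sketch of the variant names behind those tables, in Python, using a hypothetical helper:)

# Illustrative only: enumerate the sixteen shader-variant names behind add_data[2][2][2][2].
suffixes = ["_f32", "_f16"]  # index 0 -> f32, index 1 -> f16

def binary_variants(op: str) -> dict[tuple[int, int, int, int], str]:
    table = {}
    for s0 in (0, 1):              # src0 type
        for s1 in (0, 1):          # src1 type
            for d in (0, 1):       # dst type
                for rte in (0, 1): # round-to-nearest-even fp16 variant
                    table[(s0, s1, d, rte)] = op + suffixes[s0] + suffixes[s1] + suffixes[d] + ("_rte" if rte else "")
    return table

# e.g. the variant selected for f32 inputs, f16 output, on an RTE-capable device:
print(binary_variants("add")[(0, 0, 1, 1)])  # add_f32_f32_f16_rte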
