|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "ggml.h" |
| 4 | + |
| 5 | +#include <vector> |
| 6 | + |
// Identifies the vision model architecture.
// The numeric values are spelled out explicitly; they are the same values
// the enumerators receive implicitly from their declaration order.
enum vision_arch {
    VISION_ARCH_LLAVA   = 0,
    VISION_ARCH_UNKNOWN = 1, // default value of clip_hparams::arch
};
| 11 | + |
// Selects how image patches are merged.
// NOTE(review): the exact semantics of each mode are not visible in this
// header; confirm against the code that consumes mm_patch_merge_type.
// The numeric values are spelled out explicitly; they are the same values
// the enumerators receive implicitly from their declaration order.
enum mm_patch_merge {
    MM_PATCH_MERGE_FLAT          = 0, // default value of clip_hparams::mm_patch_merge_type
    MM_PATCH_MERGE_SPATIAL_UNPAD = 1,
};
| 16 | + |
| 17 | +struct clip_hparams { |
| 18 | + vision_arch arch = VISION_ARCH_UNKNOWN; |
| 19 | + |
| 20 | + uint32_t image_size; |
| 21 | + uint32_t patch_size; |
| 22 | + uint32_t hidden_size; |
| 23 | + uint32_t n_intermediate; |
| 24 | + uint32_t projection_dim; |
| 25 | + uint32_t n_head; |
| 26 | + uint32_t n_layer; |
| 27 | + uint32_t max_pos_embd; |
| 28 | + |
| 29 | + float eps; |
| 30 | + |
| 31 | + mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT; |
| 32 | + |
| 33 | + int32_t image_grid_pinpoints[32]; |
| 34 | + int32_t image_crop_resolution; |
| 35 | +}; |
| 36 | + |
// Tensor handles for a single layer of the vision model, grouped by the
// section comments below.
//
// Every pointer defaults to nullptr, so a tensor that has not been assigned
// can be detected with a null check instead of holding an indeterminate value.
// NOTE(review): ownership of the pointed-to tensors is not visible in this
// header; presumably they are owned elsewhere (e.g. by a ggml context).
struct clip_layer {
    // attention
    struct ggml_tensor * k_w = nullptr;
    struct ggml_tensor * k_b = nullptr;
    struct ggml_tensor * q_w = nullptr;
    struct ggml_tensor * q_b = nullptr;
    struct ggml_tensor * v_w = nullptr;
    struct ggml_tensor * v_b = nullptr;

    struct ggml_tensor * output_w = nullptr;
    struct ggml_tensor * output_b = nullptr;

    // layernorm 1
    struct ggml_tensor * norm_in_w = nullptr;
    struct ggml_tensor * norm_in_b = nullptr;

    // ff
    struct ggml_tensor * ffn_up_w = nullptr;
    struct ggml_tensor * ffn_up_b = nullptr;

    struct ggml_tensor * ffn_down_w = nullptr;
    struct ggml_tensor * ffn_down_b = nullptr;

    // layernorm 2
    struct ggml_tensor * norm_out_w = nullptr;
    struct ggml_tensor * norm_out_b = nullptr;
};
| 64 | + |
| 65 | +struct clip_vision_model { |
| 66 | + struct clip_hparams hparams; |
| 67 | + |
| 68 | + // embeddings |
| 69 | + struct ggml_tensor * class_embedding; |
| 70 | + struct ggml_tensor * patch_embeddings; |
| 71 | + struct ggml_tensor * patch_bias; |
| 72 | + struct ggml_tensor * position_embeddings; |
| 73 | + |
| 74 | + struct ggml_tensor * pre_norm_w; |
| 75 | + struct ggml_tensor * pre_norm_b; |
| 76 | + |
| 77 | + std::vector<clip_layer> layers; |
| 78 | + |
| 79 | + struct ggml_tensor * post_norm_w; |
| 80 | + struct ggml_tensor * post_norm_b; |
| 81 | + |
| 82 | + struct ggml_tensor * projection; |
| 83 | + |
| 84 | + // LLaVA projection |
| 85 | + struct ggml_tensor * mm_a_w = NULL; |
| 86 | + struct ggml_tensor * mm_a_b = NULL; |
| 87 | + struct ggml_tensor * mm_b_w = NULL; |
| 88 | + struct ggml_tensor * mm_b_b = NULL; |
| 89 | + |
| 90 | + struct ggml_tensor * image_newline = NULL; |
| 91 | +}; |
0 commit comments