29 | 29 | _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention' |
30 | 30 | _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward' |
31 | 31 | _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding' |
| 32 | +# TODO: b/415833584 - Improve the regex for pre-softmax layer. |
| 33 | +_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall' |
32 | 34 | _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}' |
33 | 35 |
34 | 36 |
@@ -95,10 +97,11 @@ def _set_quant_config( |
95 | 97 | rm: quantizer.recipe_manager.RecipeManager, |
96 | 98 | layer_recipe: quant_recipe.LayerQuantRecipe, |
97 | 99 | regex: str, |
| 100 | + operation_name: _OpName = _OpName.ALL_SUPPORTED, |
98 | 101 | ): |
99 | 102 | rm.add_quantization_config( |
100 | 103 | regex=regex, |
101 | | - operation_name=_OpName.ALL_SUPPORTED, |
| 104 | + operation_name=operation_name, |
102 | 105 | op_config=_OpQuantConfig( |
103 | 106 | weight_tensor_config=_TensorQuantConfig( |
104 | 107 | num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype), |
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe( |
126 | 129 |
127 | 130 | if recipe.embedding is not None: |
128 | 131 | _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR) |
| 132 | + if ( |
| 133 | + recipe._model_config is not None |
| 134 | + and recipe._model_config.lm_head_share_weight_with_embedding |
| 135 | + ): |
| 136 | + _set_quant_config( |
| 137 | + rm, |
| 138 | + recipe.embedding, |
| 139 | + _DECODE_LOGITS_REGEX_STR, |
| 140 | + _OpName.FULLY_CONNECTED, |
| 141 | + ) |
129 | 142 |
130 | 143 | if recipe.attention is not None: |
131 | 144 | if isinstance(recipe.attention, dict): |
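For readers skimming the hunks above: the new _DECODE_LOGITS_REGEX_STR ('StatefulPartitionedCall') is intentionally broad, which the new TODO acknowledges, so the extra quantization config is only registered when the LM head shares its weights with the token embedding, and it is restricted to FULLY_CONNECTED ops via the new operation_name parameter. Below is a minimal, self-contained sketch of why that restriction matters; the layer names are made-up assumptions for illustration, not names produced by this change.

import re

# Hypothetical tensor names as they might appear in a converted graph; these
# are illustrative assumptions, not output from this change.
candidate_layers = [
    'StatefulPartitionedCall:0;lm_head/MatMul',   # logits fully-connected op
    'StatefulPartitionedCall:0;final_norm/mul',   # unrelated elementwise op
]

# Same pattern as _DECODE_LOGITS_REGEX_STR in the hunk above.
decode_logits_pattern = re.compile('StatefulPartitionedCall')

# The broad pattern matches every layer under the partitioned call, so without
# the operation_name=_OpName.FULLY_CONNECTED restriction the embedding recipe
# would also be attached to ops it was never meant to cover.
print([name for name in candidate_layers if decode_logits_pattern.search(name)])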