|  | 
| 17 | 17 | 
 | 
| 18 | 18 | import ai_edge_torch.generative.layers.model_config as cfg | 
| 19 | 19 | from ai_edge_torch.generative.utilities import model_builder | 
|  | 20 | +from torch import nn | 
| 20 | 21 | 
 | 
| 21 | 22 | TENSOR_NAMES = model_builder.TENSOR_NAMES | 
| 22 | 23 | 
 | 
| 23 | 24 | 
 | 
|  | 25 | +class Qwen(model_builder.DecoderOnlyModel): | 
|  | 26 | +  """A Qwen model built from the Edge Generative API layers.""" | 
|  | 27 | +  pass | 
|  | 28 | + | 
|  | 29 | + | 
| 24 | 30 | def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig: | 
| 25 | 31 |   """Returns the model config for a Qwen 2.5 3B model. | 
| 26 | 32 | 
 | 
| @@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig: | 
| 101 | 107 |   return config | 
| 102 | 108 | 
 | 
| 103 | 109 | 
 | 
| 104 |  | -def build_3b_model( | 
| 105 |  | -    checkpoint_path: str, **kwargs | 
| 106 |  | -) -> model_builder.DecoderOnlyModel: | 
|  | 110 | +def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module: | 
| 107 | 111 |   return model_builder.build_decoder_only_model( | 
| 108 | 112 |       checkpoint_path=checkpoint_path, | 
| 109 | 113 |       config=get_3b_model_config(**kwargs), | 
| 110 | 114 |       tensor_names=TENSOR_NAMES, | 
|  | 115 | +      model_class=Qwen, | 
| 111 | 116 |   ) | 
| 112 | 117 | 
 | 
| 113 | 118 | 
 | 
| 114 |  | -def build_1_5b_model( | 
| 115 |  | -    checkpoint_path: str, **kwargs | 
| 116 |  | -) -> model_builder.DecoderOnlyModel: | 
|  | 119 | +def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: | 
| 117 | 120 |   return model_builder.build_decoder_only_model( | 
| 118 | 121 |       checkpoint_path=checkpoint_path, | 
| 119 | 122 |       config=get_1_5b_model_config(**kwargs), | 
| 120 | 123 |       tensor_names=TENSOR_NAMES, | 
|  | 124 | +      model_class=Qwen, | 
| 121 | 125 |   ) | 
| 122 | 126 | 
 | 
| 123 | 127 | 
 | 
| 124 |  | -def build_0_5b_model( | 
| 125 |  | -    checkpoint_path: str, **kwargs | 
| 126 |  | -) -> model_builder.DecoderOnlyModel: | 
|  | 128 | +def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module: | 
| 127 | 129 |   return model_builder.build_decoder_only_model( | 
| 128 | 130 |       checkpoint_path=checkpoint_path, | 
| 129 | 131 |       config=get_0_5b_model_config(**kwargs), | 
| 130 | 132 |       tensor_names=TENSOR_NAMES, | 
|  | 133 | +      model_class=Qwen, | 
| 131 | 134 |   ) | 
0 commit comments