 
 class OutputLMHead(torch.nn.Module):
     """Standalone output_lm_head block extracted from PagedLlmModelV1"""
-
+
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__()
         self.config = config
         self.hp = config.hp
-
+
         # Output normalization layer
         self.output_norm = RMSNormLayer(
-            theta("output_norm"),
-            epsilon=self.hp.attention_layer_norm_rms_epsilon
+            theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon
         )
-
+
         # Output linear layer (language model head)
         self.output_lm_head = LinearLayer(
             theta("output"),
             matmul_kernel=config.matmul_kernel,
         )
-
+
     def forward(self, h: torch.Tensor) -> torch.Tensor:
         # Apply normalization
-        h_norm = self.output_norm(h) # output fp16 and weights float32
-
+        h_norm = self.output_norm(h)  # output fp16 and weights float32
+
         # Apply final linear transformation
-        logits = self.output_lm_head(h_norm) # output and weights fp16
-
+        logits = self.output_lm_head(h_norm)  # output and weights fp16
+
         return logits
 
 
-def create_output_lm_head_from_irpa(irpa_path: str) -> tuple[OutputLMHead, torch.Tensor]:
+def create_output_lm_head_from_irpa(
+    irpa_path: str,
+) -> tuple[OutputLMHead, torch.Tensor]:
     """
     Create OutputLMHead module from IRPA file and generate sample input.
-
+
     Args:
         irpa_path: Path to the IRPA file
-
+
     Returns:
         Tuple of (OutputLMHead module, sample input tensor)
     """
     # Load dataset from IRPA file
     dataset = Dataset.load(Path(irpa_path))
-
+
     # Create model config from dataset
     llama_config = LlamaModelConfig.from_dataset(
         dataset=dataset,
         attention_kernel="torch",
         matmul_kernel="sharktank.asm;*",
         activation_dtype=torch.float16,
     )
-
+
     # Create the output LM head module
     output_lm_head = OutputLMHead(dataset.root_theta, llama_config)
-
+
     # Generate sample input tensor matching expected dimensions
     # Typical shape: [batch_size, seq_len, hidden_dim]
     # TODO: Check if there are other more suitable sizes to test.
     batch_size = 2
     seq_len = 8
-    hidden_dim = llama_config.hp.embedding_length  # Use embedding_length instead of model_dim
-
+    hidden_dim = (
+        llama_config.hp.embedding_length
+    )  # Use embedding_length instead of model_dim
+
     sample_input = torch.randn(
-        batch_size, seq_len, hidden_dim,
-        dtype=llama_config.activation_dtype
+        batch_size, seq_len, hidden_dim, dtype=llama_config.activation_dtype
     )
-
+
     return output_lm_head, sample_input
 
 
 # Test cases
-@pytest.mark.parametrize("dtype,atol", [
-    (torch.float16, 1e-4)
-])
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_output_lm_head_iree_vs_eager(request, dtype, atol):
     """
     Test OutputLMHead module comparing IREE vs PyTorch eager execution.
-
+
     Use the --parameters command-line argument to specify the IRPA file path.
     """
     # Validate and get IRPA path
     irpa_path = validate_and_get_irpa_path(request)
-
+
     try:
         # Create module and sample input from IRPA
-        module, sample_input = create_output_lm_head_from_irpa(irpa_path)
+        module, sample_input = create_output_lm_head_from_irpa(irpa_path)
     except Exception as e:
         pytest.skip(f"Failed to load model from IRPA: {e}")
 
     # Convert to desired dtype
     # module = module.to(dtype)
     sample_input = sample_input.to(dtype)
-
+
     # Run IREE vs torch comparison
-    run_iree_vs_torch_fx(module, input_args=(sample_input,), atol=atol, rtol=0,
-                         compile_flags=LLM_HIP_COMPILE_FLAGS,
-                         parameters_path=irpa_path)
+    run_iree_vs_torch_fx(
+        module,
+        input_args=(sample_input,),
+        atol=atol,
+        rtol=0,
+        compile_flags=LLM_HIP_COMPILE_FLAGS,
+        parameters_path=irpa_path,
+    )
 
 
 def test_output_lm_head_mock():
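The IREE-vs-eager test above relies on a --parameters pytest option and a validate_and_get_irpa_path helper that are defined outside this diff. A minimal sketch of what that plumbing is assumed to look like; the names, messages, and conftest.py placement are illustrative, not the repository's actual implementation:

    # conftest.py -- assumed sketch, not the actual sharktank test plumbing
    import pytest
    from pathlib import Path


    def pytest_addoption(parser):
        # Registers the --parameters option referenced in the test docstring.
        parser.addoption(
            "--parameters",
            action="store",
            default=None,
            help="Path to the IRPA parameter file used by IRPA-backed tests",
        )


    def validate_and_get_irpa_path(request):
        # Resolves the IRPA path from the command line and skips if it is unusable.
        irpa_path = request.config.getoption("--parameters")
        if irpa_path is None:
            pytest.skip("No --parameters IRPA path provided")
        if not Path(irpa_path).exists():
            pytest.skip(f"IRPA file not found: {irpa_path}")
        return irpa_path
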
@@ -118,10 +123,10 @@ def test_output_lm_head_mock():
     This test is added so it runs without requiring an IRPA file.
     """
     torch.manual_seed(42)
-
+
     # Mock configuration - provide all required parameters
     from sharktank.layers.configs import LlamaHParams
-
+
     # Create LlamaHParams with all required parameters
     hp = LlamaHParams(
         model_arch="llama",
@@ -135,41 +140,48 @@ def test_output_lm_head_mock():
         attention_head_count_kv=8,
         vocab_size=32000,
     )
-
+
     # Create mock config
     config = LlamaModelConfig(
         hp=hp,
         activation_dtype=torch.float16,
         # attention_dtype=torch.float32,
     )
-
+
     # Create mock theta with synthetic weights
     from sharktank.types import DefaultPrimitiveTensor
-
+
     # Mock output_norm weights
     output_norm_weight = torch.randn(hp.embedding_length, dtype=torch.float32)
-
-    # Mock output (lm_head) weights
+
+    # Mock output (lm_head) weights
     output_weight = torch.randn(hp.vocab_size, hp.embedding_length, dtype=torch.float16)
-
+
     # Create theta structure
     theta_dict = {
         "output_norm": {"weight": DefaultPrimitiveTensor(data=output_norm_weight)},
         "output": {"weight": DefaultPrimitiveTensor(data=output_weight)},
     }
-
+
     theta = Theta(theta_dict)
-
+
     # Create module
     module = OutputLMHead(theta, config)
-
+
     # Create sample input
     batch_size, seq_len = 2, 8
-    sample_input = torch.randn(batch_size, seq_len, hp.embedding_length, dtype=torch.float32)
-
+    sample_input = torch.randn(
+        batch_size, seq_len, hp.embedding_length, dtype=torch.float32
+    )
+
     # Run IREE vs torch comparison
-    run_iree_vs_torch_fx(module, input_args=(sample_input,), atol=1e-4, rtol=0,
-                         compile_flags=LLM_HIP_COMPILE_FLAGS,)
+    run_iree_vs_torch_fx(
+        module,
+        input_args=(sample_input,),
+        atol=1e-4,
+        rtol=0,
+        compile_flags=LLM_HIP_COMPILE_FLAGS,
+    )
 
 
 if __name__ == "__main__":
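
For completeness, a minimal eager-mode usage sketch of the extracted block, assuming a valid IRPA file is available (the path below is a placeholder; the names come from the helpers defined in the diff above):

    # Assumed usage sketch; OutputLMHead and create_output_lm_head_from_irpa are defined above.
    module, sample_input = create_output_lm_head_from_irpa("/path/to/model.irpa")
    logits = module(sample_input)  # expected shape: [batch_size, seq_len, vocab_size]

Under pytest, the IRPA path is instead supplied through the --parameters option described in the test docstring (for example, pytest <path to this test file> --parameters /path/to/model.irpa), while test_output_lm_head_mock builds synthetic weights and needs no IRPA file.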