File tree Expand file tree Collapse file tree 3 files changed +18
-16
lines changed Expand file tree Collapse file tree 3 files changed +18
-16
lines changed Original file line number Diff line number Diff line change @@ -271,16 +271,18 @@ def run_iree_module(iree_devices: list[ireert.HalDevice]):
271271 out_logits = torch .cat (self .out_logits , dim = 1 )
272272
273273 pad_logits_shape = self .token_ids .shape [1 ] - out_logits .shape [1 ]
274+ pad_logits = torch .zeros (
275+ out_logits .shape [0 ], pad_logits_shape , out_logits .shape [2 ]
276+
274277
275- if self .out_logits .dtype == torch .float8_e4m3fnuz :
276- out_logits_as_int8 = self .out_logits .view (dtype = torch .int8 )
277- self .out_logits = torch .cat ((out_logits_as_int8 , self .pad_logits ), 1 ).to (
278- self .torch_device
279- )
280- else :
281- self .out_logits = torch .cat ((self .out_logits , self .pad_logits ), 1 ).to (
282- self .torch_device
283278 )
279+
280+ out_logits = torch .cat ((out_logits , pad_logits ), 1 ).to (self .torch_device )
281+
282+ return out_logits
283+
284+ return with_iree_device_context (run_iree_module , [self .runner .config .device ])
285+
284286 @timeit
285287 def compute_perplexity (self ):
286288 from torch .nn import CrossEntropyLoss
Original file line number Diff line number Diff line change @@ -98,7 +98,7 @@ def load_model(
9898
9999 theta = dataset .root_theta
100100
101- model = PagedLlmModelV1 (theta , self . config )
101+ model = PagedLlmModelV1 (theta , config )
102102
103103 self .generator = TorchGenerator (model , tokenizer )
104104
Original file line number Diff line number Diff line change @@ -77,13 +77,13 @@ def forward(self, x):
7777 y = ops .linear (x , weight , bias )
7878 # Unconditionally dequantize.
7979 if self .q_output is not None :
80- # Probably don't need the custom kernel to return a float32 tensor as a PlanarQuantizedTensor
81- assert y . unpack (). qs . dtype == torch . float32
82- y = self . q_output . quantize ( y .unpack ().qs )
83- if self .fake_quant :
84- return y . unpack (). dequant ()
85- return y .unpack ().qs
86-
80+ if isinstance ( y , QuantizedTensor ):
81+ # Probably don't need the custom kernel to return a float32 tensor as a PlanarQuantizedTensor
82+ assert y .unpack ().qs . dtype == torch . float32
83+ y = self .q_output . quantize ( y . unpack (). qs )
84+ if self . fake_quant :
85+ return y .unpack ().dequant ()
86+ return y . unpack (). qs
8787 if isinstance (y , QuantizedTensor ):
8888 y = y .unpack ().dequant ()
8989
You can’t perform that action at this time.
0 commit comments