Skip to content

Commit a0e5031

Browse files
committed
Changes from Archana's branch
1 parent 96341ac commit a0e5031

File tree

3 files changed

+18
-16
lines changed

3 files changed

+18
-16
lines changed

sharktank/sharktank/evaluate/perplexity_iree.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -271,16 +271,18 @@ def run_iree_module(iree_devices: list[ireert.HalDevice]):
271271
out_logits = torch.cat(self.out_logits, dim=1)
272272

273273
pad_logits_shape = self.token_ids.shape[1] - out_logits.shape[1]
274+
pad_logits = torch.zeros(
275+
out_logits.shape[0], pad_logits_shape, out_logits.shape[2]
276+
274277

275-
if self.out_logits.dtype == torch.float8_e4m3fnuz:
276-
out_logits_as_int8 = self.out_logits.view(dtype=torch.int8)
277-
self.out_logits = torch.cat((out_logits_as_int8, self.pad_logits), 1).to(
278-
self.torch_device
279-
)
280-
else:
281-
self.out_logits = torch.cat((self.out_logits, self.pad_logits), 1).to(
282-
self.torch_device
283278
)
279+
280+
out_logits = torch.cat((out_logits, pad_logits), 1).to(self.torch_device)
281+
282+
return out_logits
283+
284+
return with_iree_device_context(run_iree_module, [self.runner.config.device])
285+
284286
@timeit
285287
def compute_perplexity(self):
286288
from torch.nn import CrossEntropyLoss

sharktank/sharktank/evaluate/perplexity_torch.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def load_model(
9898

9999
theta = dataset.root_theta
100100

101-
model = PagedLlmModelV1(theta, self.config)
101+
model = PagedLlmModelV1(theta, config)
102102

103103
self.generator = TorchGenerator(model, tokenizer)
104104

sharktank/sharktank/layers/linear.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -77,13 +77,13 @@ def forward(self, x):
7777
y = ops.linear(x, weight, bias)
7878
# Unconditionally dequantize.
7979
if self.q_output is not None:
80-
# Probably dont need the custom kernel to return a float32 tensor as a PlanarQuantizedTensor
81-
assert y.unpack().qs.dtype == torch.float32
82-
y = self.q_output.quantize(y.unpack().qs)
83-
if self.fake_quant:
84-
return y.unpack().dequant()
85-
return y.unpack().qs
86-
80+
if isinstance(y, QuantizedTensor):
81+
# Probably dont need the custom kernel to return a float32 tensor as a PlanarQuantizedTensor
82+
assert y.unpack().qs.dtype == torch.float32
83+
y = self.q_output.quantize(y.unpack().qs)
84+
if self.fake_quant:
85+
return y.unpack().dequant()
86+
return y.unpack().qs
8787
if isinstance(y, QuantizedTensor):
8888
y = y.unpack().dequant()
8989

0 commit comments

Comments
 (0)