Commit 6d2f43b (parent 791f183)

fix

Signed-off-by: Pawel Gadzinski <[email protected]>

2 files changed: 31 additions, 77 deletions
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py (16 additions, 36 deletions)
@@ -231,8 +231,6 @@ def make_empty(
         device: Optional[torch.device] = None,
         requires_grad: bool = False,
         pin_memory: bool = False,
-        share_scales: bool = False,
-        like: Optional[QuantizedTensor] = None,
     ) -> Float8BlockwiseQTensor:
         """Construct quantized tensor with uninitialized data"""
         if device is None:
@@ -246,25 +244,16 @@ def make_empty(

         # Allocate FP8 data
         data = None
-        rowwise_scale_inv = None
+        scale_inv = None
         if self.rowwise_usage:
             data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_rowwise_scale_inv")
-                    or like._rowwise_scale_inv is None
-                ):
-                    raise ValueError("share_scales requested but no rowwise scale tensor provided")
-                rowwise_scale_inv = like._rowwise_scale_inv
-            else:
-                scale_shape = self.get_scale_shape(shape, columnwise=False)
-                rowwise_scale_inv = torch.empty(
-                    scale_shape,
-                    dtype=torch.float32,
-                    device=device,
-                    pin_memory=pin_memory,
-                )
+            scale_shape = self.get_scale_shape(shape, columnwise=False)
+            scale_inv = torch.empty(
+                scale_shape,
+                dtype=torch.float32,
+                device=device,
+                pin_memory=pin_memory,
+            )

         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -276,30 +265,21 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_columnwise_scale_inv")
-                    or like._columnwise_scale_inv is None
-                ):
-                    raise ValueError("share_scales requested but no columnwise scale tensor provided")
-                columnwise_scale_inv = like._columnwise_scale_inv
-            else:
-                columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
-                columnwise_scale_inv = torch.empty(
-                    columnwise_scale_shape,
-                    dtype=torch.float32,
-                    device=device,
-                    pin_memory=pin_memory,
-                )
+            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
+            columnwise_scale_inv = torch.empty(
+                columnwise_scale_shape,
+                dtype=torch.float32,
+                device=device,
+                pin_memory=pin_memory,
+            )

         # Construct FP8 tensor
         return Float8BlockwiseQTensor(
             shape=shape,
             dtype=dtype,
             fp8_dtype=self.dtype,
             rowwise_data=data,
-            rowwise_scale_inv=rowwise_scale_inv,
+            rowwise_scale_inv=scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             quantizer=self,
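
The hunks above drop the `share_scales` / `like` keywords from `make_empty`, so the rowwise and columnwise scale-inverse buffers are now always freshly allocated. A minimal sketch of how a caller could still reproduce the old scale-sharing behavior on top of the new signature; the helper name is made up for illustration, and it reaches into the private `_rowwise_scale_inv` / `_columnwise_scale_inv` attributes that appear in the removed code rather than using any API from this commit:

import torch
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor

def make_empty_with_shared_scales(quantizer, ref: Float8BlockwiseQTensor, shape):
    """Hypothetical helper mimicking the removed share_scales=True path.

    Before this commit: quantizer.make_empty(shape, share_scales=True, like=ref)
    After it, make_empty always allocates fresh scale buffers, so aliasing
    ref's scales must be done explicitly by the caller.
    """
    out = quantizer.make_empty(shape, dtype=ref.dtype, device=ref.device)
    # Alias the reference tensor's scale-inverse buffers instead of the
    # freshly allocated ones (private attributes; use with care).
    out._rowwise_scale_inv = ref._rowwise_scale_inv
    out._columnwise_scale_inv = ref._columnwise_scale_inv
    return out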

transformer_engine/pytorch/tensor/nvfp4_tensor.py (15 additions, 41 deletions)
@@ -289,8 +289,6 @@ def make_empty(
         device: Optional[torch.device] = None,
         pin_memory: bool = False,
         requires_grad: bool = False,
-        share_scales: bool = False,
-        like: Optional[QuantizedTensor] = None,
     ) -> NVFP4Tensor:

         # Canonicalize tensor attributes
@@ -310,7 +308,7 @@ def make_empty(

         # Allocate FP4 data
         data = None
-        rowwise_scale_inv = None
+        scale_inv = None
         amax_rowwise = None
         if self.rowwise_usage:
             data = torch.empty(
@@ -319,24 +317,12 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_rowwise_scale_inv")
-                    or like._rowwise_scale_inv is None
-                ):
-                    raise ValueError("share_scales requested but no rowwise scale tensor provided")
-                rowwise_scale_inv = like._rowwise_scale_inv
-                amax_rowwise = getattr(like, "_amax_rowwise", None)
-            else:
-                scale_shape = self.get_scale_shape(shape, columnwise=False)
-                rowwise_scale_inv = torch.empty(
-                    scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-                )
-                # Allocate per tensor scale inverse. FP32 format.
-                amax_rowwise = torch.zeros(
-                    1, dtype=torch.float32, device=device, pin_memory=pin_memory
-                )
+            scale_shape = self.get_scale_shape(shape, columnwise=False)
+            scale_inv = torch.empty(
+                scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
+            # Allocate per tensor scale inverse. FP32 format.
+            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)

         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -352,32 +338,20 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_columnwise_scale_inv")
-                    or like._columnwise_scale_inv is None
-                ):
-                    raise ValueError(
-                        "share_scales requested but no columnwise scale tensor provided"
-                    )
-                columnwise_scale_inv = like._columnwise_scale_inv
-                amax_columnwise = getattr(like, "_amax_columnwise", None)
-            else:
-                columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
-                columnwise_scale_inv = torch.empty(
-                    columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-                )
-                amax_columnwise = torch.zeros(
-                    1, dtype=torch.float32, device=device, pin_memory=pin_memory
-                )
+            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
+            columnwise_scale_inv = torch.empty(
+                columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
+            amax_columnwise = torch.zeros(
+                1, dtype=torch.float32, device=device, pin_memory=pin_memory
+            )

         # Construct FP8 tensor
         return NVFP4Tensor(
             shape=shape,
             dtype=dtype,
             rowwise_data=data,
-            rowwise_scale_inv=rowwise_scale_inv,
+            rowwise_scale_inv=scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             amax_rowwise=amax_rowwise,
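
The NVFP4 path mirrors the FP8 blockwise change: `make_empty` no longer accepts `share_scales` / `like`, and besides the uint8 scale-inverse buffers it now always zero-initializes the per-tensor amax tensors instead of borrowing `_amax_rowwise` / `_amax_columnwise` from a reference tensor. A small sketch of the buffers that are now allocated unconditionally, using only the dtypes visible in this diff; the concrete `scale_shape` is a stand-in, since the real one comes from `quantizer.get_scale_shape(shape, columnwise=...)`:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scale_shape = (32, 4)  # stand-in; the real shape comes from get_scale_shape()

# NVFP4 scale inverses are allocated as uint8 in this diff.
scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device)

# Per-tensor amax: a single FP32 value, zero-initialized on every call
# rather than optionally shared from a `like` tensor.
amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device)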
