@@ -289,8 +289,6 @@ def make_empty(
         device: Optional[torch.device] = None,
         pin_memory: bool = False,
         requires_grad: bool = False,
-        share_scales: bool = False,
-        like: Optional[QuantizedTensor] = None,
     ) -> NVFP4Tensor:

         # Canonicalize tensor attributes
@@ -310,7 +308,7 @@ def make_empty(

         # Allocate FP4 data
         data = None
-        rowwise_scale_inv = None
+        scale_inv = None
         amax_rowwise = None
         if self.rowwise_usage:
             data = torch.empty(
@@ -319,24 +317,12 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_rowwise_scale_inv")
-                    or like._rowwise_scale_inv is None
-                ):
-                    raise ValueError("share_scales requested but no rowwise scale tensor provided")
-                rowwise_scale_inv = like._rowwise_scale_inv
-                amax_rowwise = getattr(like, "_amax_rowwise", None)
-            else:
-                scale_shape = self.get_scale_shape(shape, columnwise=False)
-                rowwise_scale_inv = torch.empty(
-                    scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-                )
-                # Allocate per tensor scale inverse. FP32 format.
-                amax_rowwise = torch.zeros(
-                    1, dtype=torch.float32, device=device, pin_memory=pin_memory
-                )
+            scale_shape = self.get_scale_shape(shape, columnwise=False)
+            scale_inv = torch.empty(
+                scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
+            # Allocate per tensor scale inverse. FP32 format.
+            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)

         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -352,32 +338,20 @@ def make_empty(
                 device=device,
                 pin_memory=pin_memory,
             )
-            if share_scales:
-                if (
-                    like is None
-                    or not hasattr(like, "_columnwise_scale_inv")
-                    or like._columnwise_scale_inv is None
-                ):
-                    raise ValueError(
-                        "share_scales requested but no columnwise scale tensor provided"
-                    )
-                columnwise_scale_inv = like._columnwise_scale_inv
-                amax_columnwise = getattr(like, "_amax_columnwise", None)
-            else:
-                columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
-                columnwise_scale_inv = torch.empty(
-                    columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-                )
-                amax_columnwise = torch.zeros(
-                    1, dtype=torch.float32, device=device, pin_memory=pin_memory
-                )
+            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
+            columnwise_scale_inv = torch.empty(
+                columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
+            )
+            amax_columnwise = torch.zeros(
+                1, dtype=torch.float32, device=device, pin_memory=pin_memory
+            )

         # Construct FP8 tensor
         return NVFP4Tensor(
             shape=shape,
             dtype=dtype,
             rowwise_data=data,
-            rowwise_scale_inv=rowwise_scale_inv,
+            rowwise_scale_inv=scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
             amax_rowwise=amax_rowwise,
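
For context, a minimal standalone sketch of what the simplified rowwise path allocates. `nvfp4_buffer_shapes` and `make_empty_nvfp4_rowwise` are hypothetical helpers written for illustration only; in the diff the scale shape comes from `self.get_scale_shape(shape, columnwise=False)` and the buffers live inside the returned `NVFP4Tensor`. The two-values-per-byte FP4 packing and 16-element scale blocks are assumptions about the NVFP4 layout, not taken from this diff.

from typing import Optional

import torch


def nvfp4_buffer_shapes(shape, block_size: int = 16):
    # Hypothetical stand-in for get_scale_shape: assumes a 2D tensor with two
    # FP4 values packed per uint8 byte and one scale per block_size elements.
    rows, cols = shape
    return (rows, cols // 2), (rows, cols // block_size)


def make_empty_nvfp4_rowwise(
    shape,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
):
    data_shape, scale_shape = nvfp4_buffer_shapes(shape)
    # Packed FP4 payload, uninitialized, stored as raw bytes.
    data = torch.empty(data_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
    # Block scale-inverse buffer, also raw uint8 bytes as in the diff.
    scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
    # Per-tensor amax: FP32 and zero-initialized, mirroring amax_rowwise above.
    amax = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
    return data, scale_inv, amax


# Usage: buffers for a 128x256 tensor.
data, scale_inv, amax = make_empty_nvfp4_rowwise((128, 256))
assert data.shape == (128, 128) and scale_inv.shape == (128, 16)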