@@ -182,13 +182,6 @@ def get_instance(cls):
         return cls._instance


-dtype2bytes = {}
-dtype2bytes[torch.float32] = 4
-dtype2bytes[torch.float16] = 2
-dtype2bytes[torch.bfloat16] = 2
-dtype2bytes[torch.uint8] = 1
-dtype2bytes[torch.int8] = 1
-
 FIRST_CUDA_DEVICE = torch.device("cuda", index=0)

 # When multiple GPUs are present, we use a context manager to
@@ -207,7 +200,7 @@ def _cuda_device_of(a: torch.Tensor):


 def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
-    num_bytes = dtype2bytes[dtype] * prod(shape)
+    num_bytes = dtype.itemsize * prod(shape)
     cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
     c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     return out


-def prefetch_tensor(A, to_cpu=False):
+def prefetch_tensor(A: torch.Tensor, to_cpu=False):
     assert A.is_paged, "Only paged tensors can be prefetched!"
     if to_cpu:
         deviceid = -1
     else:
         deviceid = A.page_deviceid

-    num_bytes = dtype2bytes[A.dtype] * A.numel()
-    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+    lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))


 def elementwise_func(func_name, A, B, value, prefetch=True):
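Side note on the byte-size changes above: `torch.dtype.itemsize` (available in recent PyTorch releases; an assumption here) and `Tensor.nbytes` report the same values the removed `dtype2bytes` table hard-coded, so both replacements are drop-in. A minimal sanity-check sketch, not part of the diff:

    import torch

    # The removed lookup table, reproduced only for comparison.
    expected = {torch.float32: 4, torch.float16: 2, torch.bfloat16: 2, torch.uint8: 1, torch.int8: 1}
    for dt, nbytes in expected.items():
        assert dt.itemsize == nbytes  # replaces dtype2bytes[dtype]

    A = torch.zeros(4, 8, dtype=torch.float16)
    assert A.nbytes == A.numel() * A.dtype.itemsize  # replaces dtype2bytes[A.dtype] * A.numel()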
@@ -499,106 +491,6 @@ def post_call(prev_device):
     torch.cuda.set_device(prev_device)


-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def get_transform_func(dtype, orderA, orderOut, transpose=False):
-    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
-    if not hasattr(lib, name):
-        print(name)
-        raise ValueError(
-            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
-        )
-    else:
-        return getattr(lib, name)
-
-
-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
-    # init_func = torch.empty
-    init_func = torch.zeros
-    dims = len(shape)
-
-    if dims == 2:
-        rows = shape[0]
-    elif dims == 3:
-        rows = shape[0] * shape[1]
-    cols = shape[-1]
-
-    state = (shape, to_order)
-    if transpose:
-        # swap dims
-        tmp = rows
-        rows = cols
-        cols = tmp
-        state = (shape[::-1], to_order)
-
-    if to_order == "row" or to_order == "col":
-        return init_func(shape, dtype=dtype, device=device), state
-    elif to_order == "col32":
-        # blocks of 32 columns (padded)
-        cols = 32 * ((cols + 31) // 32)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    elif to_order == "col_turing":
-        # blocks of 32 columns and 8 rows
-        cols = 32 * ((cols + 31) // 32)
-        rows = 8 * ((rows + 7) // 8)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    elif to_order == "col_ampere":
-        # blocks of 32 columns and 32 rows
-        cols = 32 * ((cols + 31) // 32)
-        rows = 32 * ((rows + 31) // 32)
-        return init_func((rows, cols), dtype=dtype, device=device), state
-    else:
-        raise NotImplementedError(f"To_order not supported: {to_order}")
-
-
-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def nvidia_transform(
-    A,
-    to_order,
-    from_order="row",
-    out=None,
-    transpose=False,
-    state=None,
-    ld=None,
-):
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
-    else:
-        new_state = (state[1], to_order)
-    func = get_transform_func(A.dtype, from_order, to_order, transpose)
-
-    shape = state[0]
-    if len(shape) == 2:
-        dim1 = ct.c_int32(shape[0])
-        dim2 = ct.c_int32(shape[1])
-    elif ld is not None:
-        n = prod(shape)
-        dim1 = prod([shape[i] for i in ld])
-        dim2 = ct.c_int32(n // dim1)
-        dim1 = ct.c_int32(dim1)
-    else:
-        dim1 = ct.c_int32(shape[0] * shape[1])
-        dim2 = ct.c_int32(shape[2])
-
-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
-
-    return out, new_state
-
-
 def estimate_quantiles(
     A: Tensor,
     out: Optional[torch.Tensor] = None,
@@ -1715,6 +1607,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
     return current_gnorm, clip_value, gnorm_scale


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
     assert len(histogram.shape) == 2
     assert histogram.dtype == torch.float32
@@ -2105,6 +1998,7 @@ def int8_mm_dequant(
     return result


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def get_colrow_absmax(
     A: torch.Tensor,
     row_stats: Optional[torch.Tensor] = None,
@@ -2162,6 +2056,7 @@ def get_colrow_absmax(
     return row_stats, col_stats, outlier_mask


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def get_row_absmax(A: torch.Tensor, threshold=0.0):
     """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.

@@ -2366,58 +2261,6 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)


-@deprecated(
-    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
-    category=FutureWarning,
-)
-def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
-    prev_device = pre_call(A.device)
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
-    else:
-        new_state = (state[0], to_order)  # (shape, order)
-
-    shape = state[0]
-    if len(shape) == 2:
-        dim1 = ct.c_int32(shape[0])
-        dim2 = ct.c_int32(shape[1])
-    else:
-        dim1 = ct.c_int32(shape[0] * shape[1])
-        dim2 = ct.c_int32(shape[2])
-
-    is_on_gpu([A, out])
-    if to_order == "col32":
-        if transpose:
-            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_turing":
-        if transpose:
-            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "col_ampere":
-        if transpose:
-            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
-        else:
-            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
-    elif to_order == "row":
-        if from_order == "col_turing":
-            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
-        elif from_order == "col_ampere":
-            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
-    else:
-        raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
-
-    post_call(prev_device)
-
-    return out, new_state
-
-
 def spmm_coo(
     cooA: Union[COOSparseTensor, torch.Tensor],
     B: torch.Tensor,
@@ -2692,29 +2535,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
         return x.to(dtype)
     else:
         return None
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def extract_outliers(A, SA, idx):
-    shapeA = SA[0]
-    formatA = SA[1]
-    assert formatA in ["col_turing", "col_ampere"]
-    assert A.device.type == "cuda"
-
-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
-
-    idx_size = ct.c_int32(idx.numel())
-    rows = ct.c_int32(shapeA[0])
-    cols = ct.c_int32(shapeA[1])
-    ptrA = get_ptr(A)
-    ptrIdx = get_ptr(idx)
-    ptrOut = get_ptr(out)
-
-    prev_device = pre_call(A.device)
-    if formatA == "col_turing":
-        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    elif formatA == "col_ampere":
-        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
-    post_call(prev_device)
-
-    return out
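Note on the newly added `@deprecated(..., category=FutureWarning)` markers: calls to `histogram_scatter_add_2d`, `get_colrow_absmax`, and `get_row_absmax` now emit a `FutureWarning`. A small, illustrative sketch (standard-library `warnings` only; not part of this commit) of how downstream code might handle the warnings during migration:

    import warnings

    # Promote the new deprecation warnings to errors (e.g., in CI) to locate remaining call sites...
    warnings.filterwarnings("error", category=FutureWarning)

    # ...or suppress them locally while migrating off the deprecated functions.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        # call sites of the deprecated functions go here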