@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
             ct.c_int32(blocksize),
             stream,
         )
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+    "adam": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum32bit_grad_32,
+        lib.cmomentum32bit_grad_16,
+    ),
+    "rmsprop": (
+        lib.crmsprop32bit_grad_32,
+        lib.crmsprop32bit_grad_16,
+    ),
+    "lion": (
+        lib.clion32bit_grad_fp32,
+        lib.clion32bit_grad_fp16,
+        lib.clion32bit_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad32bit_grad_32,
+        lib.cadagrad32bit_grad_16,
+    ),
+    "lamb": (
+        lib.cadam32bit_grad_fp32,
+        lib.cadam32bit_grad_fp16,
+        lib.cadam32bit_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
+}
+
+str2optimizer8bit_blockwise = {
+    "adam": (
+        lib.cadam_8bit_blockwise_grad_fp32,
+        lib.cadam_8bit_blockwise_grad_fp16,
+        lib.cadam_8bit_blockwise_grad_bf16,
+    ),
+    "momentum": (
+        lib.cmomentum_8bit_blockwise_grad_fp32,
+        lib.cmomentum_8bit_blockwise_grad_fp16,
+        lib.cmomentum_8bit_blockwise_grad_bf16,
+    ),
+    "rmsprop": (
+        lib.crmsprop_8bit_blockwise_grad_fp32,
+        lib.crmsprop_8bit_blockwise_grad_fp16,
+        lib.crmsprop_8bit_blockwise_grad_bf16,
+    ),
+    "lion": (
+        lib.clion_8bit_blockwise_grad_fp32,
+        lib.clion_8bit_blockwise_grad_fp16,
+        lib.clion_8bit_blockwise_grad_bf16,
+    ),
+    "adagrad": (
+        lib.cadagrad_8bit_blockwise_grad_fp32,
+        lib.cadagrad_8bit_blockwise_grad_fp16,
+        lib.cadagrad_8bit_blockwise_grad_bf16,
+    ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
+}
+
+
+def _optimizer_update_32bit_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    unorm_vec: Optional[torch.Tensor],
+    max_unorm: float,
+    param_norm: float,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    weight_decay: float,
+    step: int,
+    lr: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
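+    # Pick the C kernel matching the gradient dtype; bf16 is only selected when the
+    # optimizer's table entry actually provides a third (bf16) kernel.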
+    optim_fns = str2optimizer32bit.get(optimizer_name, None)
+    if optim_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
+        )
+    if g.dtype == torch.float32:
+        optim_func = optim_fns[0]
+    elif g.dtype == torch.float16:
+        optim_func = optim_fns[1]
+    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+        optim_func = optim_fns[2]
+    else:
+        raise ValueError(
+            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+        )
+
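+    # The C entry point is launched on g's CUDA device; tensors are passed as raw
+    # device pointers and scalars are wrapped in ctypes values.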
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+    optimizer_name: str,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
+    state2: Optional[torch.Tensor],
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    alpha: float,
+    eps: float,
+    step: int,
+    lr: float,
+    qmap1: torch.Tensor,
+    qmap2: Optional[torch.Tensor],
+    absmax1: torch.Tensor,
+    absmax2: Optional[torch.Tensor],
+    weight_decay: float,
+    gnorm_scale: float,
+    skip_zeros=False,
+) -> None:
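+    # The torch._check validation below is kept commented out; it documents the expected
+    # inputs (fp16/bf16/fp32 g and p of equal size, uint8 states, fp32 qmaps/absmax).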
+    # torch._check(
+    #     g.numel() == p.numel(),
+    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+    # )
+    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    # torch._check(
+    #     g.dtype in compute_dtypes,
+    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+    # )
+    # torch._check(
+    #     g.dtype == p.dtype,
+    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+    # )
+    # torch._check(
+    #     state1.dtype == torch.uint8,
+    #     lambda: f"state1 must be uint8, got {state1.dtype}",
+    # )
+    # torch._check(
+    #     qmap1.dtype == absmax1.dtype == torch.float32,
+    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+    # )
+    # if state2 is not None:
+    #     torch._check(
+    #         state2.dtype == torch.uint8,
+    #         lambda: f"state2 must be uint8, got {state2.dtype}",
+    #     )
+    #     torch._check(
+    #         qmap2.dtype == absmax2.dtype == torch.float32,
+    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+    #     )
+    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+    if optimizer_fns is None:
+        raise ValueError(
+            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+        )
+
+    if g.dtype == torch.float32:
+        optimizer_fn = optimizer_fns[0]
+    elif g.dtype == torch.float16:
+        optimizer_fn = optimizer_fns[1]
+    elif g.dtype == torch.bfloat16:
+        optimizer_fn = optimizer_fns[2]
+    else:
+        raise ValueError(
+            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+        )
+
+    with _cuda_device_of(g):
+        optimizer_fn(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
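+# Usage sketch (assumption: the "bitsandbytes::optimizer_update_*" op schemas declared
+# elsewhere in the library mirror the impl signatures above). Once registered, the CUDA
+# kernels are reached through the torch dispatcher, e.g.:
+#   torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+#       "adam", g, p, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
+#       qmap1, qmap2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros,
+#   )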