File tree Expand file tree Collapse file tree 3 files changed +13
-8
lines changed
Expand file tree Collapse file tree 3 files changed +13
-8
lines changed Original file line number Diff line number Diff line change @@ -75,6 +75,9 @@ def __init__(
7575 RuntimeError: perCh or perGrp was selected without specifying Nch, Ngrp.
7676 RuntimeError: qscheme is not allowed, or could be a typo.
7777 """
78+ # Init Nch/Ngrp to none incase they won't be set
79+ self .Nch = None
80+ self .Ngrp = None
7881 if isinstance (unit , torch .qscheme ):
7982 if "per_channel" in str (unit ):
8083 self .q_unit = "perCh"
@@ -124,8 +127,8 @@ def __repr__(self):
124127 """
125128 q_uint_str = f"qunit={ self .q_unit } "
126129 symmetric_str = f", symmetric={ self .symmetric } "
127- Nch_str = f", Nch={ self .Nch } " ,
128- Ngrp_str = f", Nch ={ self .Ngrp } " ,
130+ Nch_str = f", Nch={ self .Nch } " if self . Nch is not None else "" ,
131+ Ngrp_str = f", Ngrp ={ self .Ngrp } " if self . Ngrp is not None else " " ,
129132 single_sided_str = f", single_sided={ self .single_sided } "
130133 qlevel_lowering_str = f", qlevel_lowering={ self .qlevel_lowering } "
131134 return (
Original file line number Diff line number Diff line change @@ -89,9 +89,10 @@ def __init__(
8989 use_PT_native_Qfunc = kwargs .get ("use_PT_native_Qfunc" , False ),
9090 )
9191
92- with torch .no_grad ():
93- self .clip_valn .data *= init_clip_valn
94- self .clip_val .data *= init_clip_val
92+ if not self .training :
93+ with torch .no_grad ():
94+ self .clip_valn .data *= init_clip_valn
95+ self .clip_val .data *= init_clip_val
9596
9697 self .align_zero = align_zero
9798 self .clipSTE = clipSTE
Original file line number Diff line number Diff line change @@ -92,9 +92,10 @@ def __init__(
9292 use_PT_native_Qfunc = kwargs .get ("use_PT_native_Qfunc" , False ),
9393 )
9494
95- with torch .no_grad ():
96- self .clip_valn .data *= init_clip_valn
97- self .clip_val .data *= init_clip_val
95+ if not self .training :
96+ with torch .no_grad ():
97+ self .clip_valn .data *= init_clip_valn
98+ self .clip_val .data *= init_clip_val
9899
99100 self .clipSTE = clipSTE
100101 self .align_zero = align_zero
You can’t perform that action at this time.
0 commit comments