@@ -79,9 +79,11 @@ def __init__(
7979 RuntimeError: perCh or perGrp was selected without specifying Nch, Ngrp.
8080 RuntimeError: qscheme is not allowed, or could be a typo.
8181 """
82- # Init Nch/Ngrp to none incase they won't be set
82+ # Init perCh/perGrp vars to none incase they won't be set
8383 self .Nch = None
8484 self .Ngrp = None
85+ self .NperGrp = None
86+ self .axis = None
8587 if isinstance (unit , torch .qscheme ):
8688 if "per_channel" in str (unit ):
8789 self .q_unit = "perCh"
@@ -97,13 +99,19 @@ def __init__(
9799 self .q_unit = unit
98100 self .symmetric = symmetric
99101 if unit == "perCh" :
100- if issubclass (type (Nch ), int ):
102+ if Nch is not None and issubclass (type (Nch ), int ):
101103 assert Nch > 0 , "Provided Nch is negative"
102104 self .Nch = Nch
103105 else :
104106 raise RuntimeError (
105107 "perCh was selected without specifying Nch."
106108 )
109+ if axis is not None and issubclass (type (axis ), int ):
110+ self .axis = axis
111+ else :
112+ raise RuntimeError (
113+ "perCh was selected without specifying channel axis dimension."
114+ )
107115 elif unit == "perGrp" :
108116 if issubclass (type (NperGrp ), int ):
109117 assert NperGrp > 0 , "Provided NperGrp is negative"
@@ -130,9 +138,9 @@ def __repr__(self):
130138 """
131139 q_uint_str = f"qunit={ self .q_unit } "
132140 symmetric_str = f", symmetric={ self .symmetric } "
133- Nch_str = f", Nch={ self .Nch } " if self .Nch is not None else "" ,
134- Ngrp_str = f", NperGrp={ self .NperGrp } " if self .NperGrp else "" ,
135- NperGrp_str = f", Ngrp={ self .NperGrp } " if self .NperGrp is not None else "" ,
141+ Nch_str = f", Nch={ self .Nch } " if self .Nch is not None else ""
142+ Ngrp_str = f", NperGrp={ self .Ngrp } " if self .Ngrp is not None else ""
143+ NperGrp_str = f", Ngrp={ self .NperGrp } " if self .NperGrp is not None else ""
136144 single_sided_str = f", single_sided={ self .single_sided } "
137145 qlevel_lowering_str = f", qlevel_lowering={ self .qlevel_lowering } "
138146 return (
0 commit comments