@@ -68,7 +68,7 @@ def __init__(self, n_train, n_valid, n_wires):
6868
6969
7070class QModel (tq .QuantumModule ):
71- def __init__ (self , n_wires , n_blocks ):
71+ def __init__ (self , n_wires , n_blocks , add_fc = False ):
7272 super ().__init__ ()
7373 # inside one block, we have one u3 layer on each qubit and one
7474 # cu3 layer with ring connection
@@ -95,48 +95,56 @@ def __init__(self, n_wires, n_blocks):
9595 )
9696 )
9797 self .measure = tq .MeasureAll (tq .PauliZ )
98-
99- def forward (self , q_device : tq .QuantumDevice , input_states ):
100- # firstly set the q_device states
101- q_device .set_states (input_states )
98+ self .add_fc = add_fc
99+ if add_fc :
100+ self .fc_layer = torch .nn .Linear (n_wires , 1 )
101+
def forward(self, input_states):
    """Run the circuit on a batch of input states and return one value per sample.

    A fresh ``tq.QuantumDevice`` is created on every call, sized from the
    leading dimension of ``input_states`` and placed on its device, so the
    model does not depend on an externally managed device or a fixed batch
    size.

    Args:
        input_states: batch of states handed to ``QuantumDevice.set_states``;
            ``input_states.shape[0]`` is used as the batch size.
            NOTE(review): callers in this file pass complex64 tensors.

    Returns:
        A 1-D tensor with one prediction per sample: the output of the
        linear head when ``self.add_fc`` is true, otherwise column 1 of the
        measurement result.
    """
    # NOTE(review): self.n_wires / self.n_blocks / self.u3_layers /
    # self.cu3_layers are assumed to be set in the part of __init__ that is
    # not visible here -- confirm.
    qdev = tq.QuantumDevice(
        n_wires=self.n_wires,
        bsz=input_states.shape[0],
        device=input_states.device,
    )
    # Load the batch of input states into the device first.
    qdev.set_states(input_states)

    for k in range(self.n_blocks):
        self.u3_layers[k](qdev)
        self.cu3_layers[k](qdev)

    res = self.measure(qdev)
    if self.add_fc:
        # Linear(n_wires, 1) yields (bsz, 1); drop the trailing axis so both
        # branches return (bsz,). Otherwise F.mse_loss against 1-D targets
        # silently broadcasts to (bsz, bsz) and computes the wrong loss.
        return self.fc_layer(res).squeeze(-1)
    return res[:, 1]
108116
109117
def train(dataflow, model, device, optimizer):
    """Run one optimisation pass over ``dataflow["train"]``.

    For every batch the ``"states"`` entry is moved to ``device`` and cast to
    complex64, the ``"Xlabel"`` entry is moved to ``device`` and cast to
    float, and the model is updated with one optimizer step on the MSE loss
    between its output and the labels. The loss of each batch is printed.

    Args:
        dataflow: mapping whose ``"train"`` entry iterates over feed dicts
            with ``"states"`` and ``"Xlabel"`` tensors.
        model: callable taking the batch of states and returning predictions
            with the same shape as the labels.
        device: torch device the batch is moved to.
        optimizer: optimizer holding the model parameters.
    """
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["states"].to(device).to(torch.complex64)
        targets = feed_dict["Xlabel"].to(device).to(torch.float)

        outputs = model(inputs)
        loss = F.mse_loss(outputs, targets)

        # Clear stale gradients before back-propagating this batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}")
122130
123131
def valid_test(dataflow, split, model, device):
    """Evaluate ``model`` on ``dataflow[split]`` and report the MSE loss.

    All batches are run without gradient tracking; predictions and labels
    are concatenated along dim 0 and a single MSE loss is computed over the
    whole split, then printed.

    Args:
        dataflow: mapping from split name to an iterable of feed dicts with
            ``"states"`` and ``"Xlabel"`` tensors.
        split: key of the split to evaluate (e.g. ``"valid"``).
        model: callable taking the batch of states and returning predictions
            with the same shape as the labels.
        device: torch device the batches are moved to.

    Returns:
        The loss tensor computed over the whole split, so callers can log or
        compare it. Existing callers that ignore the result are unaffected.
    """
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["states"].to(device).to(torch.complex64)
            targets = feed_dict["Xlabel"].to(device).to(torch.float)

            outputs = model(inputs)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    loss = F.mse_loss(output_all, target_all)

    print(f"{split} set loss: {loss}")
    return loss
142150
@@ -165,6 +173,9 @@ def main():
165173 parser .add_argument (
166174 "--epochs" , type = int , default = 100 , help = "number of training epochs"
167175 )
176+ parser .add_argument (
177+ "--addfc" , action = "store_true" , help = "add a final classical FC layer"
178+ )
168179
169180 args = parser .parse_args ()
170181
@@ -202,27 +213,23 @@ def main():
202213 use_cuda = torch .cuda .is_available ()
203214 device = torch .device ("cuda" if use_cuda else "cpu" )
204215
205- model = QModel (n_wires = args .n_wires , n_blocks = args .n_blocks ).to (device )
216+ model = QModel (n_wires = args .n_wires , n_blocks = args .n_blocks , add_fc = args . addfc ).to (device )
206217
207218 n_epochs = args .epochs
208219 optimizer = optim .Adam (model .parameters (), lr = 5e-3 , weight_decay = 1e-4 )
209220 scheduler = CosineAnnealingLR (optimizer , T_max = n_epochs )
210221
211- q_device = tq .QuantumDevice (n_wires = args .n_wires )
212- q_device .reset_states (bsz = args .bsz )
213-
214222 for epoch in range (1 , n_epochs + 1 ):
215223 # train
216- print (f"Epoch { epoch } , RL : { optimizer .param_groups [0 ]['lr' ]} " )
217- train (dataflow , q_device , model , device , optimizer )
224+ print (f"Epoch { epoch } , LR : { optimizer .param_groups [0 ]['lr' ]} " )
225+ train (dataflow , model , device , optimizer )
218226
219227 # valid
220- valid_test (dataflow , q_device , "valid" , model , device )
228+ valid_test (dataflow ,"valid" , model , device )
221229 scheduler .step ()
222230
223231 # final valid
224- valid_test (dataflow , q_device , "valid" , model , device )
225-
232+ valid_test (dataflow , "valid" , model , device )
226233
227234if __name__ == "__main__" :
228235 main ()
0 commit comments