11import torchquantum as tq
22import torch
3- import torch .nn .functional as F
43from torchquantum .vqe_utils import parse_hamiltonian_file
5- from torchquantum .datasets import VQE
64import random
75import numpy as np
86import argparse
97import torch .optim as optim
108
11- from torch .optim .lr_scheduler import CosineAnnealingLR , ConstantLR
9+ from torch .optim .lr_scheduler import CosineAnnealingLR
10+ from torchquantum .measurement import expval_joint_analytical
1211
1312
1413class QVQEModel (tq .QuantumModule ):
@@ -32,55 +31,54 @@ def __init__(self, arch, hamil_info):
3231 trainable = True ,
3332 circular = True
3433 ))
35- self .measure = tq .MeasureMultipleTimes (
36- obs_list = hamil_info ['hamil_list' ])
3734
38- def forward (self , q_device ):
39- q_device .reset_states (bsz = 1 )
40- for k in range (self .n_blocks ):
41- self .u3_layers [k ](q_device )
42- self .cu3_layers [k ](q_device )
43- x = self .measure (q_device )
44-
45- hamil_coefficients = torch .tensor ([hamil ['coefficient' ] for hamil in
46- self .hamil_info ['hamil_list' ]],
47- device = x .device ).double ()
48-
49- for k , hamil in enumerate (self .hamil_info ['hamil_list' ]):
50- for wire , observable in zip (hamil ['wires' ], hamil ['observables' ]):
51- if observable == 'i' :
52- x [k ][wire ] = 1
53- for wire in range (q_device .n_wires ):
54- if wire not in hamil ['wires' ]:
55- x [k ][wire ] = 1
56-
57- x = torch .cumprod (x , dim = - 1 )[:, - 1 ].double ()
58- x = torch .dot (x , hamil_coefficients )
35+ def forward (self ):
36+ qdev = tq .QuantumDevice (n_wires = self .n_wires , bsz = 1 , device = next (self .parameters ()).device )
5937
60- if x .dim () == 0 :
61- x = x .unsqueeze (0 )
62-
63- return x
38+ for k in range (self .n_blocks ):
39+ self .u3_layers [k ](qdev )
40+ self .cu3_layers [k ](qdev )
41+
42+ expval = 0
43+ for hamil in self .hamil_info ['hamil_list' ]:
44+ expval += expval_joint_analytical (qdev , observable = hamil ["pauli_string" ]) * hamil ["coeff" ]
6445
46+ return expval
6547
def train(model, optimizer, n_steps=1):
    """Run ``n_steps`` gradient steps minimizing the model's energy.

    Args:
        model: zero-argument callable returning a scalar, differentiable
            energy expectation (the VQE loss).
        optimizer: torch optimizer over the model's parameters.
        n_steps: number of optimization steps to perform per call.
    """
    for _ in range(n_steps):
        energy = model()
        # Standard torch step: clear stale grads, backprop, update.
        optimizer.zero_grad()
        energy.backward()
        optimizer.step()
        print(f"Expectation of energy: {energy.item()}")
7556
7657
def valid_test(model):
    """Evaluate the model's energy once with gradient tracking disabled.

    Args:
        model: zero-argument callable returning the scalar energy expectation.
    """
    # Inference only — no autograd graph is needed for validation.
    with torch.no_grad():
        energy = model()

    print(f"validation: expectation of energy: {energy.item()}")
64+
def process_hamil_info(hamil_info):
    """Convert per-term wire/observable lists into full-length Pauli strings.

    Each entry of ``hamil_info['hamil_list']`` with ``wires``,
    ``observables`` and ``coefficient`` keys is rewritten as
    ``{"pauli_string": ..., "coeff": ...}`` where the Pauli string has one
    upper-case character per wire ("I" on wires the term does not act on).

    Args:
        hamil_info: dict with keys ``hamil_list`` (list of term dicts) and
            ``n_wires`` (total qubit count). Mutated in place.

    Returns:
        The same ``hamil_info`` dict, with ``hamil_list`` replaced by the
        processed terms.
    """
    n_wires = hamil_info["n_wires"]
    processed = []

    for term in hamil_info["hamil_list"]:
        # Map each acted-on wire to its upper-cased Pauli label once,
        # instead of a linear list.index() scan for every qubit.
        obs_on_wire = {
            wire: obs.upper()
            for wire, obs in zip(term["wires"], term["observables"])
        }
        pauli_string = "".join(
            obs_on_wire.get(qubit, "I") for qubit in range(n_wires)
        )
        processed.append(
            {"pauli_string": pauli_string, "coeff": term["coefficient"]}
        )

    hamil_info["hamil_list"] = processed
    return hamil_info
8482
8583
8684def main ():
@@ -94,7 +92,7 @@ def main():
9492 help = 'number of training epochs' )
9593 parser .add_argument ('--epochs' , type = int , default = 100 ,
9694 help = 'number of training epochs' )
97- parser .add_argument ('--hamil_filename' , type = str , default = './h2_new .txt' ,
95+ parser .add_argument ('--hamil_filename' , type = str , default = './h2 .txt' ,
9896 help = 'number of training epochs' )
9997
10098 args = parser .parse_args ()
@@ -108,49 +106,27 @@ def main():
108106 np .random .seed (seed )
109107 torch .manual_seed (seed )
110108
111- dataset = VQE (steps_per_epoch = args .steps_per_epoch )
112-
113- dataflow = dict ()
114-
115- for split in dataset :
116- if split == 'train' :
117- sampler = torch .utils .data .RandomSampler (dataset [split ])
118- else :
119- sampler = torch .utils .data .SequentialSampler (dataset [split ])
120- dataflow [split ] = torch .utils .data .DataLoader (
121- dataset [split ],
122- batch_size = 1 ,
123- sampler = sampler ,
124- num_workers = 1 ,
125- pin_memory = True )
126-
127- hamil_info = parse_hamiltonian_file (args .hamil_filename )
109+ hamil_info = process_hamil_info (parse_hamiltonian_file (args .hamil_filename ))
128110
129111 use_cuda = torch .cuda .is_available ()
130112 device = torch .device ("cuda" if use_cuda else "cpu" )
131- model = QVQEModel (arch = {"n_blocks" : args .n_blocks },
132- hamil_info = hamil_info )
113+ model = QVQEModel (arch = {"n_blocks" : args .n_blocks }, hamil_info = hamil_info )
133114
134115 model .to (device )
135116
136117 n_epochs = args .epochs
137118 optimizer = optim .Adam (model .parameters (), lr = 5e-3 , weight_decay = 1e-4 )
138119 scheduler = CosineAnnealingLR (optimizer , T_max = n_epochs )
139120
140- q_device = tq .QuantumDevice (n_wires = hamil_info ['n_wires' ])
141- q_device .reset_states (bsz = 1 )
142-
143121 for epoch in range (1 , n_epochs + 1 ):
144122 # train
145123 print (f"Epoch { epoch } , LR: { optimizer .param_groups [0 ]['lr' ]} " )
146- train (dataflow , q_device , model , device , optimizer )
124+ train (model , optimizer , n_steps = args . steps_per_epoch )
147125
148- # valid
149- valid_test (dataflow , q_device , 'valid' , model , device )
150126 scheduler .step ()
151127
152128 # final valid
153- valid_test (dataflow , q_device , 'valid' , model , device )
129+ valid_test (model )
154130
155131
156132if __name__ == '__main__' :
0 commit comments