Skip to content

Commit d1acf15

Browse files
authored
Merge pull request #78 from mit-han-lab/dev
[minor] fix VQE
2 parents 75ae209 + af65a30 commit d1acf15

File tree

1 file changed

+42
-66
lines changed

1 file changed

+42
-66
lines changed

examples/simple_vqe/simple_vqe.py

Lines changed: 42 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import torchquantum as tq
22
import torch
3-
import torch.nn.functional as F
43
from torchquantum.vqe_utils import parse_hamiltonian_file
5-
from torchquantum.datasets import VQE
64
import random
75
import numpy as np
86
import argparse
97
import torch.optim as optim
108

11-
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR
9+
from torch.optim.lr_scheduler import CosineAnnealingLR
10+
from torchquantum.measurement import expval_joint_analytical
1211

1312

1413
class QVQEModel(tq.QuantumModule):
@@ -32,55 +31,54 @@ def __init__(self, arch, hamil_info):
3231
trainable=True,
3332
circular=True
3433
))
35-
self.measure = tq.MeasureMultipleTimes(
36-
obs_list=hamil_info['hamil_list'])
3734

38-
def forward(self, q_device):
39-
q_device.reset_states(bsz=1)
40-
for k in range(self.n_blocks):
41-
self.u3_layers[k](q_device)
42-
self.cu3_layers[k](q_device)
43-
x = self.measure(q_device)
44-
45-
hamil_coefficients = torch.tensor([hamil['coefficient'] for hamil in
46-
self.hamil_info['hamil_list']],
47-
device=x.device).double()
48-
49-
for k, hamil in enumerate(self.hamil_info['hamil_list']):
50-
for wire, observable in zip(hamil['wires'], hamil['observables']):
51-
if observable == 'i':
52-
x[k][wire] = 1
53-
for wire in range(q_device.n_wires):
54-
if wire not in hamil['wires']:
55-
x[k][wire] = 1
56-
57-
x = torch.cumprod(x, dim=-1)[:, -1].double()
58-
x = torch.dot(x, hamil_coefficients)
35+
def forward(self):
36+
qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=1, device=next(self.parameters()).device)
5937

60-
if x.dim() == 0:
61-
x = x.unsqueeze(0)
62-
63-
return x
38+
for k in range(self.n_blocks):
39+
self.u3_layers[k](qdev)
40+
self.cu3_layers[k](qdev)
41+
42+
expval = 0
43+
for hamil in self.hamil_info['hamil_list']:
44+
expval += expval_joint_analytical(qdev, observable=hamil["pauli_string"]) * hamil["coeff"]
6445

46+
return expval
6547

66-
def train(dataflow, q_device, model, device, optimizer):
67-
for _ in dataflow['train']:
68-
outputs = model(q_device)
69-
loss = outputs.mean()
7048

49+
def train(model, optimizer, n_steps=1):
50+
for _ in range(n_steps):
51+
loss = model()
7152
optimizer.zero_grad()
7253
loss.backward()
7354
optimizer.step()
7455
print(f"Expectation of energy: {loss.item()}")
7556

7657

77-
def valid_test(dataflow, q_device, split, model, device):
58+
def valid_test(model):
7859
with torch.no_grad():
79-
for _ in dataflow[split]:
80-
outputs = model(q_device)
81-
loss = outputs.mean()
60+
loss = model()
61+
62+
print(f"validation: expectation of energy: {loss.item()}")
8263

83-
print(f"Expectation of energy: {loss}")
64+
65+
def process_hamil_info(hamil_info):
66+
hamil_list = hamil_info['hamil_list']
67+
n_wires = hamil_info["n_wires"]
68+
all_info = []
69+
70+
for hamil in hamil_list:
71+
pauli_string = ""
72+
for i in range(n_wires):
73+
if i in hamil['wires']:
74+
wire = hamil['wires'].index(i)
75+
pauli_string += (hamil['observables'][wire].upper())
76+
else:
77+
pauli_string += "I"
78+
all_info.append({"pauli_string": pauli_string,
79+
"coeff": hamil['coefficient']})
80+
hamil_info['hamil_list'] = all_info
81+
return hamil_info
8482

8583

8684
def main():
@@ -94,7 +92,7 @@ def main():
9492
help='number of training epochs')
9593
parser.add_argument('--epochs', type=int, default=100,
9694
help='number of training epochs')
97-
parser.add_argument('--hamil_filename', type=str, default='./h2_new.txt',
95+
parser.add_argument('--hamil_filename', type=str, default='./h2.txt',
9896
help='number of training epochs')
9997

10098
args = parser.parse_args()
@@ -108,49 +106,27 @@ def main():
108106
np.random.seed(seed)
109107
torch.manual_seed(seed)
110108

111-
dataset = VQE(steps_per_epoch=args.steps_per_epoch)
112-
113-
dataflow = dict()
114-
115-
for split in dataset:
116-
if split == 'train':
117-
sampler = torch.utils.data.RandomSampler(dataset[split])
118-
else:
119-
sampler = torch.utils.data.SequentialSampler(dataset[split])
120-
dataflow[split] = torch.utils.data.DataLoader(
121-
dataset[split],
122-
batch_size=1,
123-
sampler=sampler,
124-
num_workers=1,
125-
pin_memory=True)
126-
127-
hamil_info = parse_hamiltonian_file(args.hamil_filename)
109+
hamil_info = process_hamil_info(parse_hamiltonian_file(args.hamil_filename))
128110

129111
use_cuda = torch.cuda.is_available()
130112
device = torch.device("cuda" if use_cuda else "cpu")
131-
model = QVQEModel(arch={"n_blocks": args.n_blocks},
132-
hamil_info=hamil_info)
113+
model = QVQEModel(arch={"n_blocks": args.n_blocks}, hamil_info=hamil_info)
133114

134115
model.to(device)
135116

136117
n_epochs = args.epochs
137118
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
138119
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
139120

140-
q_device = tq.QuantumDevice(n_wires=hamil_info['n_wires'])
141-
q_device.reset_states(bsz=1)
142-
143121
for epoch in range(1, n_epochs + 1):
144122
# train
145123
print(f"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}")
146-
train(dataflow, q_device, model, device, optimizer)
124+
train(model, optimizer, n_steps=args.steps_per_epoch)
147125

148-
# valid
149-
valid_test(dataflow, q_device, 'valid', model, device)
150126
scheduler.step()
151127

152128
# final valid
153-
valid_test(dataflow, q_device, 'valid', model, device)
129+
valid_test(model)
154130

155131

156132
if __name__ == '__main__':

0 commit comments

Comments
 (0)