Skip to content

Commit 2d2268c

Browse files
authored
Merge pull request #93 from mit-han-lab/dev
[minor] add examples for 2qubit 4class mnist;
2 parents 88d8174 + da5c109 commit 2d2268c

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
use 2 qubit to perform 4 class classification,
3+
We can choose four different observables to measure the qubit state:
4+
1. XX
5+
2. YY
6+
3. ZZ
7+
4. XY
8+
"""
9+
10+
import torch
11+
import torch.nn.functional as F
12+
import torch.optim as optim
13+
import argparse
14+
15+
import torchquantum as tq
16+
import torchquantum.functional as tqf
17+
18+
from torchquantum.measurement import expval_joint_analytical
19+
20+
from torchquantum.datasets import MNIST
21+
from torch.optim.lr_scheduler import CosineAnnealingLR
22+
23+
import random
24+
import numpy as np
25+
26+
27+
class QFCModel(tq.QuantumModule):
    """Hybrid quantum-classical classifier: 2 qubits, 4 output classes.

    Each class score is the analytically computed expectation value of one
    two-qubit Pauli observable (XX, YY, ZZ, XY) on the final circuit state;
    log-softmax over the four expectation values gives the class log-probs.
    """

    class QLayer(tq.QuantumModule):
        """Variational layer: a fixed random circuit followed by four
        trainable rotation/controlled-rotation gates."""

        def __init__(self):
            super().__init__()
            self.n_wires = 2
            # 50 randomly sampled ops spanning both wires
            self.random_layer = tq.RandomLayer(
                n_ops=50, wires=list(range(self.n_wires))
            )
            # gates with trainable parameters
            self.rx0 = tq.RX(has_params=True, trainable=True)
            self.ry0 = tq.RY(has_params=True, trainable=True)
            self.rz0 = tq.RZ(has_params=True, trainable=True)
            self.crx0 = tq.CRX(has_params=True, trainable=True)

        def forward(self, qdev: tq.QuantumDevice):
            self.random_layer(qdev)
            # trainable gates (instantiated ahead of time in __init__)
            self.rx0(qdev, wires=0)
            self.ry0(qdev, wires=1)
            self.rz0(qdev, wires=0)
            self.crx0(qdev, wires=[0, 1])

    def __init__(self):
        super().__init__()
        self.n_wires = 2
        # the encoder here is just for illustration; it may not be the best
        # choice (16 input features -> 8 rotation gates per wire)
        self.encoder = tq.GeneralEncoder(
            tq.encoder_op_list_name_dict["2x8_rxryrzrxryrzrxry"]
        )
        self.q_layer = self.QLayer()

    def forward(self, x, use_qiskit=False):
        # NOTE(review): use_qiskit is accepted but never read here; the
        # forward pass always runs the analytical simulator — confirm intent.
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz=bsz, device=x.device, record_op=True
        )

        # 28x28 image -> 4x4 via 6x6 average pooling -> 16 features/sample
        features = F.avg_pool2d(x, 6).view(bsz, 16)

        self.encoder(qdev, features)
        self.q_layer(qdev)

        # one observable per class, measured on the same final state
        expvals = [
            expval_joint_analytical(qdev, pauli)
            for pauli in ("XX", "YY", "ZZ", "XY")
        ]
        logits = torch.stack(expvals, dim=1)
        return F.log_softmax(logits, dim=1)
78+
79+
80+
def train(dataflow, model, device, optimizer):
    """Run one epoch of training over ``dataflow["train"]``.

    For every batch: forward pass, NLL loss against the digit labels,
    one optimizer step. The running batch loss is printed on a single
    self-overwriting line.
    """
    for batch in dataflow["train"]:
        images = batch["image"].to(device)
        labels = batch["digit"].to(device)

        log_probs = model(images)
        loss = F.nll_loss(log_probs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # carriage return keeps the loss on one updating console line
        print(f"loss: {loss.item()}", end="\r")
91+
92+
93+
def valid_test(dataflow, split, model, device, qiskit=False):
    """Evaluate ``model`` on ``dataflow[split]``; print accuracy and NLL loss.

    Runs the whole split under ``torch.no_grad()``, concatenates all batch
    outputs/targets, then reports top-1 accuracy and the mean NLL loss.
    """
    targets_list = []
    outputs_list = []
    with torch.no_grad():
        for batch in dataflow[split]:
            images = batch["image"].to(device)
            labels = batch["digit"].to(device)

            preds = model(images, use_qiskit=qiskit)

            targets_list.append(labels)
            outputs_list.append(preds)

    all_targets = torch.cat(targets_list, dim=0)
    all_outputs = torch.cat(outputs_list, dim=0)

    # top-1 accuracy over the whole split
    _, top1 = all_outputs.topk(1, dim=1)
    hits = top1.eq(all_targets.view(-1, 1).expand_as(top1))
    accuracy = hits.sum().item() / all_targets.shape[0]
    loss = F.nll_loss(all_outputs, all_targets).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")
117+
118+
119+
def main():
    """Parse CLI args, train the 2-qubit 4-class MNIST model, then test it.

    Seeds all RNGs for reproducibility, builds DataLoaders for the
    train/valid/test splits (digits 0-3 only), trains for ``--epochs``
    epochs with Adam + cosine LR annealing, validating after each epoch,
    and finally evaluates on the held-out test set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--static", action="store_true", help="compute with static mode"
    )
    parser.add_argument("--pdb", action="store_true", help="debug with pdb")
    parser.add_argument(
        "--wires-per-block",
        type=int,
        default=2,
        # fixed typo in help text: "int static mode" -> "in static mode"
        help="wires per block in static mode",
    )
    parser.add_argument(
        "--epochs", type=int, default=5, help="number of training epochs"
    )

    args = parser.parse_args()

    if args.pdb:
        import pdb

        pdb.set_trace()

    # seed every RNG source so runs are reproducible
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # digits 0-3 -> 4 classes; 100 test samples keep the example fast
    dataset = MNIST(
        root="./mnist_data",
        train_valid_split_ratio=[0.9, 0.1],
        digits_of_interest=[0, 1, 2, 3],
        n_test_samples=100,
    )

    dataflow = dict()

    for split in dataset:
        sampler = torch.utils.data.RandomSampler(dataset[split])
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=256,
            sampler=sampler,
            num_workers=8,
            pin_memory=True,
        )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = QFCModel().to(device)

    n_epochs = args.epochs
    optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    for epoch in range(1, n_epochs + 1):
        # train
        print(f"Epoch {epoch}:")
        train(dataflow, model, device, optimizer)
        print(optimizer.param_groups[0]["lr"])

        # valid
        valid_test(dataflow, "valid", model, device)
        scheduler.step()

    # test
    valid_test(dataflow, "test", model, device, qiskit=False)


if __name__ == "__main__":
    main()

torchquantum/encoding.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,24 @@ def __init__(self):
187187
{"input_idx": [0], "func": "ry", "wires": [0]},
188188
{"input_idx": [1], "func": "ry", "wires": [1]},
189189
],
190+
"2x8_rxryrzrxryrzrxry": [
191+
{"input_idx": [0], "func": "rx", "wires": [0]},
192+
{"input_idx": [1], "func": "rx", "wires": [1]},
193+
{"input_idx": [2], "func": "ry", "wires": [0]},
194+
{"input_idx": [3], "func": "ry", "wires": [1]},
195+
{"input_idx": [4], "func": "rz", "wires": [0]},
196+
{"input_idx": [5], "func": "rz", "wires": [1]},
197+
{"input_idx": [6], "func": "rx", "wires": [0]},
198+
{"input_idx": [7], "func": "rx", "wires": [1]},
199+
{"input_idx": [8], "func": "ry", "wires": [0]},
200+
{"input_idx": [9], "func": "ry", "wires": [1]},
201+
{"input_idx": [10], "func": "rz", "wires": [0]},
202+
{"input_idx": [11], "func": "rz", "wires": [1]},
203+
{"input_idx": [12], "func": "rx", "wires": [0]},
204+
{"input_idx": [13], "func": "rx", "wires": [1]},
205+
{"input_idx": [14], "func": "ry", "wires": [0]},
206+
{"input_idx": [15], "func": "ry", "wires": [1]},
207+
],
190208
"3x1_ryryry": [
191209
{"input_idx": [0], "func": "ry", "wires": [0]},
192210
{"input_idx": [1], "func": "ry", "wires": [1]},
@@ -231,6 +249,24 @@ def __init__(self):
231249
{"input_idx": [14], "func": "ry", "wires": [2]},
232250
{"input_idx": [15], "func": "ry", "wires": [3]},
233251
],
252+
"8x2_ry": [
253+
{"input_idx": [0], "func": "ry", "wires": [0]},
254+
{"input_idx": [1], "func": "ry", "wires": [1]},
255+
{"input_idx": [2], "func": "ry", "wires": [2]},
256+
{"input_idx": [3], "func": "ry", "wires": [3]},
257+
{"input_idx": [4], "func": "ry", "wires": [4]},
258+
{"input_idx": [5], "func": "ry", "wires": [5]},
259+
{"input_idx": [6], "func": "ry", "wires": [6]},
260+
{"input_idx": [7], "func": "ry", "wires": [7]},
261+
{"input_idx": [8], "func": "ry", "wires": [0]},
262+
{"input_idx": [9], "func": "ry", "wires": [1]},
263+
{"input_idx": [10], "func": "ry", "wires": [2]},
264+
{"input_idx": [11], "func": "ry", "wires": [3]},
265+
{"input_idx": [12], "func": "ry", "wires": [4]},
266+
{"input_idx": [13], "func": "ry", "wires": [5]},
267+
{"input_idx": [14], "func": "ry", "wires": [6]},
268+
{"input_idx": [15], "func": "ry", "wires": [7]},
269+
],
234270
"16_ry": [
235271
{"input_idx": [0], "func": "ry", "wires": [0]},
236272
{"input_idx": [1], "func": "ry", "wires": [1]},

0 commit comments

Comments
 (0)