Skip to content

Commit 30c79ad

Browse files
committed
[minor] revise examples
1 parent da5c109 commit 30c79ad

File tree

2 files changed

+164
-23
lines changed

2 files changed

+164
-23
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import torch
2+
import torch.nn.functional as F
3+
4+
import torchquantum as tq
5+
import torchquantum.functional as tqf
6+
7+
from torchquantum.datasets import MNIST
8+
from torchquantum.operators import op_name_dict
9+
from typing import List
10+
11+
# NOTE(review): removed a leftover debugger breakpoint (`import pdb` followed by
# `pdb.set_trace()`). It unconditionally halted this example script at import
# time and serves no purpose in committed example code.
13+
14+
15+
16+
class TQNet(tq.QuantumModule):
    """Quantum classifier: an optional classical-to-quantum encoder followed
    by a stack of quantum layers, measured on all wires in the Pauli-Z basis.

    When ``use_softmax`` is true the measurement expectations are passed
    through ``log_softmax`` so the output is directly compatible with
    ``F.nll_loss``.
    """

    def __init__(self, layers: List[tq.QuantumModule], encoder=None, use_softmax=False):
        super().__init__()

        self.encoder = encoder
        self.use_softmax = use_softmax
        self.service = "TorchQuantum"
        self.measure = tq.MeasureAll(tq.PauliZ)

        # QuantumModuleList registers each sub-module so its parameters
        # are tracked by torchquantum/torch.
        self.layers = tq.QuantumModuleList()
        for module in layers:
            self.layers.append(module)

    def forward(self, device, x):
        """Run one forward pass.

        Args:
            device: ``tq.QuantumDevice``; its states are reset for each batch.
            x: image batch. Average-pooled with kernel 6 and flattened to 16
               features per sample (28x28 MNIST -> 4x4 -- TODO confirm input size).

        Returns:
            Measurement expectations, optionally log-softmaxed over dim 1.
        """
        batch = x.shape[0]
        device.reset_states(batch)

        # Downsample to a 16-dim classical feature vector per sample.
        features = F.avg_pool2d(x, 6).view(batch, 16)

        if self.encoder:
            self.encoder(device, features)

        for module in self.layers:
            module(device)

        outcome = self.measure(device)
        return F.log_softmax(outcome, dim=1) if self.use_softmax else outcome
50+
51+
class TQLayer(tq.QuantumModule):
    """Thin container that applies a fixed sequence of quantum gates in order."""

    def __init__(self, gates: List[tq.QuantumModule]):
        super().__init__()

        self.service = "TorchQuantum"

        # Register every gate so its trainable parameters are tracked.
        self.layer = tq.QuantumModuleList()
        for gate_module in gates:
            self.layer.append(gate_module)

    @tq.static_support
    def forward(self, q_device):
        """Apply each registered gate to ``q_device``, in insertion order."""
        for gate_module in self.layer:
            gate_module(q_device)
65+
66+
def train_tq(model, device, train_dl, epochs, loss_fn, optimizer, torch_dev=None):
    """Train ``model`` on ``train_dl`` and return per-epoch mean losses.

    Args:
        model: callable taking ``(device, x)`` and returning predictions.
        device: quantum device forwarded to ``model`` on every batch.
        train_dl: iterable of dicts with ``'image'`` and ``'digit'`` tensors.
        epochs: number of full passes over ``train_dl``.
        loss_fn: loss taking ``(preds, targets)``.
        optimizer: torch optimizer over the model's parameters.
        torch_dev: classical torch device for the input tensors. Defaults to
            the module-level ``torch_device`` global, which the original
            implementation read implicitly (kept for backward compatibility;
            pass it explicitly in new code).

    Returns:
        list[float]: average training loss for each epoch.
    """
    if torch_dev is None:
        # Backward-compatible fallback to the module-level global.
        torch_dev = torch_device

    losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        batches = 0
        for batch_dict in train_dl:
            x = batch_dict['image'].to(torch_dev)
            # Labels must be int64 for NLL/cross-entropy style losses.
            y = batch_dict['digit'].to(torch.long).to(torch_dev)

            optimizer.zero_grad()

            preds = model(device, x)

            loss = loss_fn(preds, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches += 1

            # In-place progress line; overwritten each batch via "\r".
            print(f"Epoch {epoch + 1} | Loss: {running_loss/batches}", end="\r")

        print(f"Epoch {epoch + 1} | Loss: {running_loss/batches}")
        losses.append(running_loss/batches)

    return losses
98+
99+
# ---- Script setup: build model, data loaders, and run one training epoch ----

torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encode the 16 pooled pixel values as rotation angles via four u3 gates.
# Alternatives: encoder = None, or encoder = tq.AmplitudeEncoder()
encoder = tq.MultiPhaseEncoder(['u3', 'u3', 'u3', 'u3'])

# One random circuit layer plus one layer of trainable (parameterized) gates.
random_layer = tq.RandomLayer(n_ops=50, wires=list(range(4)))
gate_specs = [('rx', [0]), ('ry', [1]), ('rz', [3]), ('crx', [0, 2])]
trainable_layer = TQLayer(
    [op_name_dict[name](trainable=True, has_params=True, wires=wires)
     for name, wires in gate_specs]
)
layers = [random_layer, trainable_layer]

device = tq.QuantumDevice(n_wires=4).to(torch_device)

model = TQNet(layers=layers, encoder=encoder, use_softmax=True).to(torch_device)

# log-softmax outputs + NLL loss together form cross-entropy.
loss_fn = F.nll_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

dataset = MNIST(
    root='./mnist_data',
    train_valid_split_ratio=[.9, .1],
    digits_of_interest=[0, 1, 3, 6],
    n_test_samples=200,
)


def _make_loader(split):
    # Shuffled DataLoader over a single dataset split.
    return torch.utils.data.DataLoader(
        dataset[split],
        batch_size=32,
        sampler=torch.utils.data.RandomSampler(dataset[split]),
    )


train_dl = _make_loader('train')
val_dl = _make_loader('valid')
test_dl = _make_loader('test')

print("--Training--")
train_losses = train_tq(model, device, train_dl, 1, loss_fn, optimizer)
134+

examples/regression/run_regression.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, n_train, n_valid, n_wires):
6868

6969

7070
class QModel(tq.QuantumModule):
71-
def __init__(self, n_wires, n_blocks):
71+
def __init__(self, n_wires, n_blocks, add_fc=False):
7272
super().__init__()
7373
# inside one block, we have one u3 layer one each qubit and one layer
7474
# cu3 layer with ring connection
@@ -95,48 +95,56 @@ def __init__(self, n_wires, n_blocks):
9595
)
9696
)
9797
self.measure = tq.MeasureAll(tq.PauliZ)
98-
99-
def forward(self, q_device: tq.QuantumDevice, input_states):
100-
# firstly set the q_device states
101-
q_device.set_states(input_states)
98+
self.add_fc = add_fc
99+
if add_fc:
100+
self.fc_layer = torch.nn.Linear(n_wires, 1)
101+
102+
def forward(self, input_states):
103+
qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=input_states.shape[0], device=input_states.device)
104+
# firstly set the qdev states
105+
qdev.set_states(input_states)
102106
for k in range(self.n_blocks):
103-
self.u3_layers[k](q_device)
104-
self.cu3_layers[k](q_device)
107+
self.u3_layers[k](qdev)
108+
self.cu3_layers[k](qdev)
105109

106-
res = self.measure(q_device)
110+
res = self.measure(qdev)
111+
if self.add_fc:
112+
res = self.fc_layer(res)
113+
else:
114+
res = res[:, 1]
107115
return res
108116

109117

110-
def train(dataflow, q_device, model, device, optimizer):
118+
def train(dataflow, model, device, optimizer):
111119
for feed_dict in dataflow["train"]:
112120
inputs = feed_dict["states"].to(device).to(torch.complex64)
113121
targets = feed_dict["Xlabel"].to(device).to(torch.float)
114122

115-
outputs = model(q_device, inputs)
123+
outputs = model(inputs)
116124

117-
loss = F.mse_loss(outputs[:, 1], targets)
125+
loss = F.mse_loss(outputs, targets)
118126
optimizer.zero_grad()
119127
loss.backward()
120128
optimizer.step()
121129
print(f"loss: {loss.item()}")
122130

123131

124-
def valid_test(dataflow, q_device, split, model, device):
132+
def valid_test(dataflow, split, model, device):
125133
target_all = []
126134
output_all = []
127135
with torch.no_grad():
128136
for feed_dict in dataflow[split]:
129137
inputs = feed_dict["states"].to(device).to(torch.complex64)
130138
targets = feed_dict["Xlabel"].to(device).to(torch.float)
131139

132-
outputs = model(q_device, inputs)
140+
outputs = model(inputs)
133141

134142
target_all.append(targets)
135143
output_all.append(outputs)
136144
target_all = torch.cat(target_all, dim=0)
137145
output_all = torch.cat(output_all, dim=0)
138146

139-
loss = F.mse_loss(output_all[:, 1], target_all)
147+
loss = F.mse_loss(output_all, target_all)
140148

141149
print(f"{split} set loss: {loss}")
142150

@@ -165,6 +173,9 @@ def main():
165173
parser.add_argument(
166174
"--epochs", type=int, default=100, help="number of training epochs"
167175
)
176+
parser.add_argument(
177+
"--addfc", action="store_true", help="add a final classical FC layer"
178+
)
168179

169180
args = parser.parse_args()
170181

@@ -202,27 +213,23 @@ def main():
202213
use_cuda = torch.cuda.is_available()
203214
device = torch.device("cuda" if use_cuda else "cpu")
204215

205-
model = QModel(n_wires=args.n_wires, n_blocks=args.n_blocks).to(device)
216+
model = QModel(n_wires=args.n_wires, n_blocks=args.n_blocks, add_fc=args.addfc).to(device)
206217

207218
n_epochs = args.epochs
208219
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
209220
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
210221

211-
q_device = tq.QuantumDevice(n_wires=args.n_wires)
212-
q_device.reset_states(bsz=args.bsz)
213-
214222
for epoch in range(1, n_epochs + 1):
215223
# train
216-
print(f"Epoch {epoch}, RL: {optimizer.param_groups[0]['lr']}")
217-
train(dataflow, q_device, model, device, optimizer)
224+
print(f"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}")
225+
train(dataflow, model, device, optimizer)
218226

219227
# valid
220-
valid_test(dataflow, q_device, "valid", model, device)
228+
valid_test(dataflow, "valid", model, device)
221229
scheduler.step()
222230

223231
# final valid
224-
valid_test(dataflow, q_device, "valid", model, device)
225-
232+
valid_test(dataflow, "valid", model, device)
226233

227234
if __name__ == "__main__":
228235
main()

0 commit comments

Comments
 (0)