-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbench.py
More file actions
116 lines (86 loc) · 3.47 KB
/
bench.py
File metadata and controls
116 lines (86 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models import InputWiseGateLayer, OriginalParametrizationLayer, thermometer_encode
# -- CONFIGURATION --
BATCH_SIZE = 1  # single-sample batches: emphasizes per-step overhead in the benchmark
LEARNING_RATE = 0.01  # SGD learning rate for both the logic layer and the classifier
MAX_STEPS = 5000 # short run for demonstration purposes
DEVICE = "cpu"  # benchmark targets CPU; change to "cuda" to measure GPU timing
NUM_GATES = 128 # width of the logic layer
THRESHOLD_K = 15 # For thermometer encoding
def get_cifar_loader():
    """
    Build a shuffled DataLoader over the CIFAR-10 training split.

    Images are converted to tensors in the [0, 1] range — required by the
    thermometer encoding applied downstream — with no further normalization.
    Downloads the dataset into ./data on first use.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        # Values must stay in [0, 1] for thermometer encoding to work
    ])
    train_set = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=to_tensor,
    )
    return DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
def train_benchmark(model_type: str = "IWP"):
    """
    Train one logic-gate layer plus a linear head on CIFAR-10 and time it.

    Args:
        model_type: "IWP" selects the input-wise parametrization (~4 params
            per gate); any other value selects the original parametrization
            (~16 params per gate).

    Returns:
        tuple[float, int]: (total wall-clock seconds for the training loop,
        number of optimizer steps actually executed).
    """
    # NOTE: the original signature was `model_type:"IWP"`, which made "IWP" a
    # type *annotation* rather than the intended default value.
    print(f"\n--- Starting Benchmark: {model_type} ---")
    # 1. Initialize Model
    if model_type == "IWP":
        # 4 params per gate
        layer = InputWiseGateLayer.InputWiseGateLayer(NUM_GATES).to(DEVICE)
    else:
        # 16 params per gate
        layer = OriginalParametrizationLayer.OriginalParametrizationLayer(NUM_GATES).to(DEVICE)
    # Simple linear classifier on top (Differentiable Logic Gate Network)
    classifier = nn.Linear(NUM_GATES, 10).to(DEVICE)
    # One optimizer over both the logic layer and the classifier head
    optimizer = optim.SGD(list(layer.parameters()) + list(classifier.parameters()), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    dataloader = get_cifar_loader()
    # Metrics
    start_time = time.time()
    steps = 0
    running_loss = 0.0
    log_interval = 1000  # steps between progress reports (also the loss-averaging window)
    model_size = sum(p.numel() for p in layer.parameters())
    print(f"Logic Layer Parameters: {model_size} (Expected ~{NUM_GATES*4} for IWP, ~{NUM_GATES*16} for OP)")
    # Training Loop
    layer.train()
    classifier.train()
    for images, labels in dataloader:
        if steps >= MAX_STEPS:
            break
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        # A. Preprocessing (Thermometer Encoding)
        # Flatten: (1, 3, 32, 32) -> (1, 3072)
        x_flat = images.view(images.size(0), -1)
        # Encode: (1, 3072) -> (1, 3072 * K)
        x_encoded = thermometer_encode.thermometer_encode(x_flat, THRESHOLD_K)
        # Each gate consumes 2 inputs, so keep only the first 2 * NUM_GATES features
        x_input = x_encoded[:, :NUM_GATES * 2]
        # B. Forward Pass
        optimizer.zero_grad()
        gates_output = layer(x_input)
        logits = classifier(gates_output)
        loss = criterion(logits, labels)
        # C. Backward Pass (Measure this speed!)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        steps += 1
        if steps % log_interval == 0:
            elapsed = time.time() - start_time
            avg_loss = running_loss / log_interval
            print(f"Step {steps}: Loss {avg_loss:.4f} | Elapsed: {elapsed:.2f}s")
            running_loss = 0.0
    total_time = time.time() - start_time
    print(f"Benchmark Complete. Total Time: {total_time:.4f}s")
    # Guard against an empty dataloader — avoids ZeroDivisionError
    if steps > 0:
        print(f"Time per step: {(total_time / steps) * 1000:.2f}ms")
    return total_time, steps
if __name__ == "__main__":
    # Run both parametrizations under identical settings and report the
    # wall-clock speedup of the input-wise (IWP) variant over the original (OP).
    # Run Comparison
    time_op, _ = train_benchmark("OP")
    time_iwp, _ = train_benchmark("IWP")
    print("\n--- RESULTS ---")
    print(f"Original Parametrization Time: {time_op:.2f}s")
    print(f"Input-Wise Parametrization Time: {time_iwp:.2f}s")
    # Speedup > 1 means IWP trained faster than OP
    print(f"Speedup: {time_op / time_iwp:.2f}x")