Skip to content

Commit 7b361ad

Browse files
committed
version and compatibility fix
1 parent 9937528 commit 7b361ad

File tree

5 files changed

+50
-71
lines changed

5 files changed

+50
-71
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
2020
```
2121

22+
* Install Python 3.10.15 and create a virtual environment
23+
```
24+
sudo apt install python3.10 python3.10-venv python3.10-dev
25+
python3.10 -m venv venv
26+
source venv/bin/activate
27+
```
28+
29+
2230
* Install the dependencies (including PyTorch)
2331
```
2432
pip install -r requirements.txt

evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def evaluate(model, loss_fn, dataloader, metrics, params):
4141

4242
# move to GPU if available
4343
if params.cuda:
44-
data_batch, labels_batch = data_batch.cuda(async=True), labels_batch.cuda(async=True)
44+
data_batch, labels_batch = data_batch.cuda(), labels_batch.cuda()
4545
# fetch the next evaluation batch
4646
data_batch, labels_batch = Variable(data_batch), Variable(labels_batch)
4747

@@ -94,7 +94,7 @@ def evaluate_kd(model, dataloader, metrics, params):
9494

9595
# move to GPU if available
9696
if params.cuda:
97-
data_batch, labels_batch = data_batch.cuda(async=True), labels_batch.cuda(async=True)
97+
data_batch, labels_batch = data_batch.cuda(), labels_batch.cuda()
9898
# fetch the next evaluation batch
9999
data_batch, labels_batch = Variable(data_batch), Variable(labels_batch)
100100

requirements.txt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
scipy==1.0.0
2-
numpy==1.14.0
3-
Pillow==8.1.1
4-
tabulate==0.8.2
5-
tensorflow==1.7.0rc0
6-
torch==0.3.0.post4
7-
torchvision==0.2.0
1+
scipy==1.14.1
2+
numpy==1.25.0
3+
Pillow==9.0.0
4+
tabulate==0.5
5+
tensorflow==2.8.0rc0
6+
torch==1.13.0
7+
torchvision==0.14.0
88
tqdm==4.19.8
99
torchnet
10+
protobuf == 3.20

train.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def train(model, optimizer, loss_fn, dataloader, metrics, params):
5858
for i, (train_batch, labels_batch) in enumerate(dataloader):
5959
# move to GPU if available
6060
if params.cuda:
61-
train_batch, labels_batch = train_batch.cuda(async=True), \
62-
labels_batch.cuda(async=True)
61+
train_batch, labels_batch = train_batch.cuda(), \
62+
labels_batch.cuda()
6363
# convert to torch Variables
6464
train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)
6565

@@ -186,8 +186,8 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
186186
for i, (train_batch, labels_batch) in enumerate(dataloader):
187187
# move to GPU if available
188188
if params.cuda:
189-
train_batch, labels_batch = train_batch.cuda(async=True), \
190-
labels_batch.cuda(async=True)
189+
train_batch, labels_batch = train_batch.cuda(), \
190+
labels_batch.cuda()
191191
# convert to torch Variables
192192
train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)
193193

@@ -199,7 +199,7 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
199199
with torch.no_grad():
200200
output_teacher_batch = teacher_model(train_batch)
201201
if params.cuda:
202-
output_teacher_batch = output_teacher_batch.cuda(async=True)
202+
output_teacher_batch = output_teacher_batch.cuda()
203203

204204
loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch, params)
205205

@@ -213,17 +213,17 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
213213
# Evaluate summaries only once in a while
214214
if i % params.save_summary_steps == 0:
215215
# extract data from torch Variable, move to cpu, convert to numpy arrays
216-
output_batch = output_batch.data.cpu().numpy()
217-
labels_batch = labels_batch.data.cpu().numpy()
216+
output_batch = output_batch.detach().cpu().numpy()
217+
labels_batch = labels_batch.detach().cpu().numpy()
218218

219219
# compute all metrics on this batch
220220
summary_batch = {metric:metrics[metric](output_batch, labels_batch)
221221
for metric in metrics}
222-
summary_batch['loss'] = loss.data[0]
222+
summary_batch['loss'] = loss.item()
223223
summ.append(summary_batch)
224224

225225
# update the average loss
226-
loss_avg.update(loss.data[0])
226+
loss_avg.update(loss.item())
227227

228228
t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
229229
t.update()

utils.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def load_checkpoint(checkpoint, model, optimizer=None):
148148
optimizer: (torch.optim) optional: resume optimizer from checkpoint
149149
"""
150150
if not os.path.exists(checkpoint):
151-
raise("File doesn't exist {}".format(checkpoint))
151+
raise(FileNotFoundError("File doesn't exist {}".format(checkpoint)))
152152
if torch.cuda.is_available():
153153
checkpoint = torch.load(checkpoint)
154154
else:
@@ -163,65 +163,35 @@ def load_checkpoint(checkpoint, model, optimizer=None):
163163
return checkpoint
164164

165165

166-
class Board_Logger(object):
167-
"""Tensorboard log utility"""
168-
166+
class BoardLogger:
167+
"""TensorBoard log utility for TensorFlow 2.x"""
169168
def __init__(self, log_dir):
170-
"""Create a summary writer logging to log_dir."""
171-
self.writer = tf.summary.FileWriter(log_dir)
169+
self.writer = tf.summary.create_file_writer(log_dir)
172170

173171
def scalar_summary(self, tag, value, step):
174172
"""Log a scalar variable."""
175-
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
176-
self.writer.add_summary(summary, step)
173+
with self.writer.as_default():
174+
tf.summary.scalar(tag, value, step=step)
175+
self.writer.flush()
177176

178177
def image_summary(self, tag, images, step):
179178
"""Log a list of images."""
179+
with self.writer.as_default():
180+
for i, img in enumerate(images):
181+
# Convert image to a TensorFlow-compatible format
182+
if isinstance(img, np.ndarray):
183+
img = tf.convert_to_tensor(img, dtype=tf.uint8)
184+
if img.ndim == 2: # Add channel dimension for grayscale images
185+
img = tf.expand_dims(img, axis=-1)
186+
tf.summary.image(f"{tag}/{i}", tf.expand_dims(img, 0), step=step) # Add batch dimension
187+
self.writer.flush()
180188

181-
img_summaries = []
182-
for i, img in enumerate(images):
183-
# Write the image to a string
184-
try:
185-
s = StringIO()
186-
except:
187-
s = BytesIO()
188-
scipy.misc.toimage(img).save(s, format="png")
189-
190-
# Create an Image object
191-
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
192-
height=img.shape[0],
193-
width=img.shape[1])
194-
# Create a Summary value
195-
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
196-
197-
# Create and write Summary
198-
summary = tf.Summary(value=img_summaries)
199-
self.writer.add_summary(summary, step)
200-
201189
def histo_summary(self, tag, values, step, bins=1000):
202190
"""Log a histogram of the tensor of values."""
203-
204-
# Create a histogram using numpy
205-
counts, bin_edges = np.histogram(values, bins=bins)
206-
207-
# Fill the fields of the histogram proto
208-
hist = tf.HistogramProto()
209-
hist.min = float(np.min(values))
210-
hist.max = float(np.max(values))
211-
hist.num = int(np.prod(values.shape))
212-
hist.sum = float(np.sum(values))
213-
hist.sum_squares = float(np.sum(values**2))
214-
215-
# Drop the start of the first bin
216-
bin_edges = bin_edges[1:]
217-
218-
# Add bin edges and counts
219-
for edge in bin_edges:
220-
hist.bucket_limit.append(edge)
221-
for c in counts:
222-
hist.bucket.append(c)
223-
224-
# Create and write Summary
225-
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
226-
self.writer.add_summary(summary, step)
227-
self.writer.flush()
191+
with self.writer.as_default():
192+
# Create histogram data using numpy
193+
counts, bin_edges = np.histogram(values, bins=bins)
194+
195+
# Create a histogram summary
196+
tf.summary.histogram(tag, values, step=step)
197+
self.writer.flush()

0 commit comments

Comments
 (0)