-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathPrintWeights.py
More file actions
86 lines (75 loc) · 2.92 KB
/
PrintWeights.py
File metadata and controls
86 lines (75 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import division
import torch
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from ChessEnvironment import ChessEnvironment
import os
import ChessResNet
# Name of the checkpoint being visualized; also used as the output sub-directory.
networkName = "Checkpoint 1 Weights"
model = ChessResNet.ResNetDoubleHead()
# map_location="cpu" allows a checkpoint saved from GPU tensors to load on a
# machine without CUDA; the weights are converted with .numpy() below, which
# only works on CPU tensors.
model.load_state_dict(torch.load("New Networks/weights-ckpt1.pt", map_location="cpu"))

# Flatten the stem convolution into a stack of 2-D kernels, keeping each
# kernel's own (height, width) instead of hard-coding 225 kernels of size 3x3.
conv1_weights = model.conv1.weight.data.numpy()
init_layer = conv1_weights.reshape((-1,) + tuple(conv1_weights.shape[-2:]))

# Collect the conv1/conv2 weight tensors of residual blocks 0 and 1 in each of
# layer1..layer4, in that order (16 tensors in total).
# NOTE(review): the original comment read "WE HAVE 20 BLOCKS! oops." -- confirm
# against ChessResNet.ResNetDoubleHead whether some blocks are not collected here.
blocks = []
for layer in (model.layer1, model.layer2, model.layer3, model.layer4):
    for block_index in (0, 1):
        residual_block = layer[block_index]
        blocks.append(residual_block.conv1.weight.data.numpy())
        blocks.append(residual_block.conv2.weight.data.numpy())
# PRINT FIRST LAYER OF CONVOLUTIONS
# Draw every 2-D kernel of the initial convolution on one square grid of
# grayscale images and save the figure. The grid side is derived from the
# kernel count (225 kernels -> the original 15x15 layout).
init_grid = int(np.ceil(np.sqrt(len(init_layer))))
plt.figure(figsize=(7, 7))
for idx, filt in enumerate(init_layer):
    plt.subplot(init_grid, init_grid, idx + 1)
    plt.imshow(filt, cmap="gray")
    plt.axis('off')
saveDirec = 'Visualization of Network/' + networkName+'/Initial Conv Layer'
if not os.path.exists(saveDirec):
    os.makedirs(saveDirec)
# No extension is given, so matplotlib appends its default savefig format.
plt.savefig(saveDirec+'/Kernels in Initial Convolutional Layer')
plt.close()
# PRINT REST OF BLOCKS
# For each collected convolution weight tensor (indexed as
# [out_filter, in_channel, kh, kw]), save one figure per input channel
# ("Part"): the figure shows that input-channel slice of every output filter.
MAX_PARTS_PER_BLOCK = 15  # number of input-channel slices plotted per tensor
for h, weights in enumerate(blocks):
    # Clamp so a tensor with fewer than 15 input channels no longer raises IndexError.
    num_parts = min(MAX_PARTS_PER_BLOCK, weights.shape[1])
    # Keep the original 16x16 layout; grow it only if more filters than fit.
    grid = max(16, int(np.ceil(np.sqrt(len(weights)))))
    saveFolder = 'Visualization of Network/'+networkName+'/Block ' + str(h+1)
    if not os.path.exists(saveFolder):
        os.makedirs(saveFolder)
    for i in range(num_parts):
        title = "Kernels in Block " + str(h+1) + ", Part " + str(i+1)
        fig = plt.figure(figsize=(10, 10))
        # canvas.set_window_title was removed in matplotlib 3.6; the window
        # title now lives on the figure manager (absent for some canvases).
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None and hasattr(manager, "set_window_title"):
            manager.set_window_title(title)
        for idx, filt in enumerate(weights):
            plt.subplot(grid, grid, idx + 1)
            plt.imshow(filt[i, :, :], cmap="gray")
            plt.axis('off')
        saveDirec = saveFolder + '/' + title
        plt.savefig(saveDirec)
        plt.close()
# Construct a ChessEnvironment and encode its board with boardToState().
board = ChessEnvironment()
representation = board.boardToState()

# Set to True to display each plane of the encoded board. This code used to be
# disabled by wrapping it in a string literal; a flag keeps it runnable.
SHOW_BOARD_REPRESENTATION = False
if SHOW_BOARD_REPRESENTATION:
    # PRINT BOARD AND ITS REPRESENTATION
    # Assumes representation[0] holds at most 15 2-D planes (5x3 grid) -- TODO confirm.
    plt.figure(figsize=(4, 4))
    for idx, filt in enumerate(representation[0]):
        plt.subplot(5, 3, idx + 1)
        plt.imshow(filt, cmap="gray")
        plt.axis('off')
    plt.show()