diff --git a/README.md b/README.md
index 49717fe..542f614 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,23 @@ CNN + MLP occured overfitting to the training data.
Relational networks show far better results on both relational and non-relational questions.
+## Application Demo
+Accuracy curves from the retraining experiments (20 epochs, logged in `RN_1_log.csv`):
+
+![Non-relational accuracy](readme_img/non_relational_acc.png)
+![Binary relational accuracy](readme_img/binary_relational_acc.png)
+![Ternary relational accuracy](readme_img/ternary_relational_acc.png)
+
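+These curves can be reproduced by regenerating the dataset and retraining, e.g.:
+
+    $ python sort_of_clevr_generator.py
+    $ python main.py --epochs 20
+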
+You can randomly generate 2D objects, drag them around the canvas, and select or edit a question about the scene.
+
+ $ python application.py
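+
+The tokenizer in `utils/preditor_util.py` expects each question to start with its type index (1-6) followed by a period, and to name one of the six object colors (red, green, blue, orange, gray, yellow), for example:
+
+    1.What is the shape of the red object ?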
+
+![Application demo](readme_img/relational-network-application.gif)
+
## Contributions
[@gngdb](https://github.com/gngdb) speeds up the model by 10 times.
+
+[@neural022](https://github.com/neural022) and [@hhhlll21qq](https://github.com/hhhlll21qq) built the application demo.
diff --git a/RN_1_log.csv b/RN_1_log.csv
new file mode 100644
index 0000000..13da807
--- /dev/null
+++ b/RN_1_log.csv
@@ -0,0 +1,21 @@
+epoch,train_acc_ternary,train_acc_rel,train_acc_norel,test_acc_ternary,test_acc_rel,test_acc_norel
+1,51.555355976485956,50.55825440888308,52.975996080992815,54.13306451612903,61.29032258064516,55.69556451612903
+2,53.33830013063357,66.2342831482691,55.337606139777925,52.368951612903224,72.88306451612904,56.703629032258064
+3,53.53833278902678,69.31539843239713,56.32654310907903,51.965725806451616,72.22782258064517,58.71975806451613
+4,53.70672762900065,69.78996570868713,58.261552906597,52.167338709677416,73.38709677419355,61.44153225806452
+5,53.97615937295885,70.14104343566297,60.37924559111692,52.318548387096776,72.53024193548387,63.457661290322584
+6,54.60687459177009,70.68807152188113,61.68456074461137,53.024193548387096,71.52217741935483,64.21370967741936
+7,55.12634715871979,70.62377531025473,62.19382756368387,52.671370967741936,72.58064516129032,63.608870967741936
+8,55.39782005225343,71.11058948399739,62.6500244937949,53.931451612903224,74.24395161290323,63.76008064516129
+9,55.49477465708687,71.87704114957545,63.30013063357283,54.48588709677419,74.04233870967742,66.22983870967742
+10,55.601935009797515,72.61491672109732,66.2832707380797,53.42741935483871,74.49596774193549,71.27016129032258
+11,55.741753755715216,73.1078543435663,77.11258981058133,53.78024193548387,74.44556451612904,86.18951612903226
+12,55.741753755715216,73.47015839320706,93.81633736120183,54.435483870967744,76.00806451612904,98.03427419354838
+13,55.940765839320704,76.68394839973874,98.29564010450686,54.939516129032256,80.54435483870968,98.5383064516129
+14,56.174477465708684,82.49101894186806,99.06107119529719,53.881048387096776,84.77822580645162,97.88306451612904
+15,56.41125081645983,84.75771554539517,99.32744121489223,55.39314516129032,85.38306451612904,98.7399193548387
+16,56.56841933376878,86.46513716525146,99.5009389288047,55.292338709677416,85.6350806451613,98.94153225806451
+17,56.68068256041803,87.47448563030699,99.63463422599608,56.30040322580645,87.3991935483871,98.99193548387096
+18,57.09911822338341,88.2388961463096,99.67851894186806,55.645161290322584,87.3991935483871,99.14314516129032
+19,57.25730731548008,88.99208033964729,99.7305682560418,56.149193548387096,88.20564516129032,99.19354838709677
+20,57.507348138471585,89.61258981058133,99.7601649248857,56.09879032258065,88.10483870967742,98.89112903225806
diff --git a/application.py b/application.py
new file mode 100644
index 0000000..bff7e54
--- /dev/null
+++ b/application.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Jan 18 10:29:41 2022
+
+@author: helen, george chen(neural022)
+"""
+
+import sys
+import cv2
+from PyQt5 import QtWidgets
+from utils import Ui_MainWindow, RNPredictor
+from PIL import ImageQt,Image
+from numpy import asarray
+import time
+
+
+class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
+ def __init__(self, parent=None):
+ super(MainWindow, self).__init__(parent=parent)
+ self.setupUi(self)
+ self.rn_predictor = RNPredictor()
+
+ self.pushButton.clicked.connect(self.pushButton_clicked)
+ self.pushButton_2.clicked.connect(self.label.pushButton2_clicked)
+
+    def pushButton_clicked(self):  # "OK" button: grab the canvas and run a prediction
+ label_image = ImageQt.fromqpixmap(self.label.grab())
+ image_RGB = Image.new("RGB", label_image.size, (255, 255, 255))
+ image_RGB.paste(label_image, mask=label_image.split()[3])
+ image_RGB = image_RGB.resize((75, 75), Image.ANTIALIAS)
+ image_RGB.save('image_RGB.jpg', 'JPEG', quality=100)
+ img_rgb_array = asarray(image_RGB)
+ img_bgr_array = cv2.cvtColor(img_rgb_array, cv2.COLOR_RGB2BGR)
+ # print(img_bgr_array.shape)#(75,75,3)
+ time.sleep(0.5)
+ # Question ComboBox
+ # Answer label4
+ if self.comboBox.currentText() != '':
+ self.question = self.comboBox.currentText()
+ print('Question:', self.question)
+ question = self.rn_predictor.tokenize(self.question)
+ self.answer = self.rn_predictor.predict((img_bgr_array/255, question))
+ print('Answer:', self.answer)
+ self.label4.setText(self.answer)
+
+
+if __name__ == "__main__":
+ app = QtWidgets.QApplication([])
+ app.setStyle('Fusion')
+ w = MainWindow()
+ w.show()
+
+ sys.exit(app.exec())
+
\ No newline at end of file
diff --git a/best_model/epoch_RN_20.pth b/best_model/epoch_RN_20.pth
new file mode 100644
index 0000000..90117d8
Binary files /dev/null and b/best_model/epoch_RN_20.pth differ
diff --git a/environment.yml b/environment.yml
index b3ac08c..9d708ea 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,28 +1,103 @@
name: RN3
channels:
+ - pytorch
- defaults
dependencies:
- - ca-certificates
- - certifi
- - libcxx
- - libcxxabi
- - libedit
- - libffi
- - ncurses
- - openssl
- - pip=20.0.2=py38_1
- - python=3.8.1
- - readline
- - setuptools
- - sqlite=3.31.1
- - tk
- - wheel
- - xz
- - zlib
+ - blas=1.0=mkl
+ - ca-certificates=2021.10.26=haa95532_2
+ - certifi=2021.10.8=py38haa95532_0
+ - cudatoolkit=11.3.1=h59b6b97_2
+ - freetype=2.10.4=hd328e21_0
+ - intel-openmp=2021.4.0=haa95532_3556
+ - jpeg=9d=h2bbff1b_0
+ - libpng=1.6.37=h2a8f88b_0
+ - libtiff=4.2.0=hd0e1b90_0
+ - libuv=1.40.0=he774522_0
+ - libwebp=1.2.0=h2bbff1b_0
+ - lz4-c=1.9.3=h2bbff1b_1
+ - mkl=2021.4.0=haa95532_640
+ - mkl-service=2.4.0=py38h2bbff1b_0
+ - mkl_fft=1.3.1=py38h277e83a_0
+ - mkl_random=1.2.2=py38hf11a4ad_0
+ - numpy-base=1.21.2=py38h0829f74_0
+ - olefile=0.46=pyhd3eb1b0_0
+ - openssl=1.1.1l=h2bbff1b_0
+ - pillow=8.4.0=py38hd45dc43_0
+ - pip=21.2.2=py38haa95532_0
+ - python=3.8.1=h5fd99cc_8_cpython
+ - pytorch=1.10.1=py3.8_cuda11.3_cudnn8_0
+ - pytorch-mutex=1.0=cuda
+ - setuptools=58.0.4=py38haa95532_0
+ - six=1.16.0=pyhd3eb1b0_0
+ - sqlite=3.31.1=h2a8f88b_1
+ - tk=8.6.11=h2bbff1b_0
+ - torchaudio=0.10.1=py38_cu113
+ - torchvision=0.11.2=py38_cu113
+ - typing_extensions=3.10.0.2=pyh06a4308_0
+ - vc=14.2=h21ff451_1
+ - vs2015_runtime=14.27.29016=h5e58377_2
+ - wheel=0.37.0=pyhd3eb1b0_1
+ - wincertstore=0.2=py38haa95532_2
+ - xz=5.2.5=h62dcd97_0
+ - zlib=1.2.11=h8cc25b3_4
+ - zstd=1.4.9=h19a0ad4_0
- pip:
+ - absl-py==1.0.0
+ - altgraph==0.17.2
+ - cachetools==4.2.4
+ - charset-normalizer==2.0.9
+ - click==7.1.2
+ - colorama==0.4.4
+ - filelock==3.4.2
+ - future==0.18.2
+ - google-auth==1.35.0
+ - google-auth-oauthlib==0.4.6
+ - grpcio==1.43.0
+ - huggingface-hub==0.4.0
+ - idna==3.3
+ - importlib-metadata==4.10.0
+ - joblib==1.1.0
+ - markdown==3.3.6
- numpy==1.18.1
+ - oauthlib==3.1.1
- opencv-python==4.2.0.32
- - torch==1.4.0
+ - packaging==21.3
+ - pandas==1.3.5
+ - pefile==2021.9.3
+ - protobuf==3.19.1
+ - pyasn1==0.4.8
+ - pyasn1-modules==0.2.8
+ - pyinstaller==4.8
+ - pyinstaller-hooks-contrib==2021.5
+ - pyparsing==3.0.6
+ - pyqt5==5.15.6
+ - pyqt5-plugins==5.15.4.2.2
+ - pyqt5-qt5==5.15.2
+ - pyqt5-sip==12.9.0
+ - python-dateutil==2.8.2
+ - python-dotenv==0.19.2
+ - pytz==2021.3
+ - pywin32-ctypes==0.2.0
+ - pyyaml==6.0
+ - qt5-applications==5.15.2.2.2
+ - qt5-tools==5.15.2.1.2
+ - regex==2021.11.10
+ - requests==2.27.0
+ - requests-oauthlib==1.3.0
+ - rsa==4.8
+ - sacremoses==0.0.47
+ - scikit-learn==1.0.2
+ - scipy==1.7.3
+ - sklearn==0.0
- tensorboard==2.2.0
- tensorboard-plugin-wit==1.6.0.post2
-prefix: ~/miniconda3/envs/RN3
+ - threadpoolctl==3.0.0
+ - tokenizers==0.10.3
+ - torch-tb-profiler==0.3.1
+ - tqdm==4.62.3
+ - transformers==4.15.0
+ - typing-extensions==4.0.1
+ - urllib3==1.26.7
+ - werkzeug==2.0.2
+ - zipp==3.7.0
+prefix: D:\Coding\Anaconda3\envs\RN3
diff --git a/image_RGB.jpg b/image_RGB.jpg
new file mode 100644
index 0000000..401f471
Binary files /dev/null and b/image_RGB.jpg differ
diff --git a/main.py b/main.py
index b3535bd..0c98aed 100644
--- a/main.py
+++ b/main.py
@@ -20,32 +20,24 @@
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
-parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN',
- help='resume from model stored')
-parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
-parser.add_argument('--epochs', type=int, default=20, metavar='N',
- help='number of epochs to train (default: 20)')
-parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
- help='learning rate (default: 0.0001)')
-parser.add_argument('--no-cuda', action='store_true', default=False,
- help='disables CUDA training')
-parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
-parser.add_argument('--log-interval', type=int, default=10, metavar='N',
- help='how many batches to wait before logging training status')
-parser.add_argument('--resume', type=str,
- help='resume from model stored')
-parser.add_argument('--relation-type', type=str, default='binary',
- help='what kind of relations to learn. options: binary, ternary (default: binary)')
+parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN', help='model architecture to train (RN or CNN_MLP)')
+parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
+parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
+parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
+parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
+parser.add_argument('--resume', type=str, help='resume from model stored')
+parser.add_argument('--relation-type', type=str, default='binary', help='what kind of relations to learn. options: binary, ternary (default: binary)')
args = parser.parse_args()
-args.cuda = not args.no_cuda and torch.cuda.is_available()
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+args.cuda = (device.type == 'cuda') and not args.no_cuda
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
-
summary_writer = SummaryWriter()
if args.model=='CNN_MLP':
@@ -59,6 +51,7 @@
input_qst = torch.FloatTensor(bs, 18)
label = torch.LongTensor(bs)
+# print("gpu:", args.cuda)
if args.cuda:
model.cuda()
input_img = input_img.cuda()
@@ -70,6 +63,8 @@
label = Variable(label)
def tensor_data(data, i):
+
+    # slice out batch i: elements [bs*i : bs*(i+1)]
img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)]))
qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)]))
ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)]))
@@ -78,7 +73,6 @@ def tensor_data(data, i):
input_qst.data.resize_(qst.size()).copy_(qst)
label.data.resize_(ans.size()).copy_(ans)
-
def cvt_data_axis(data):
img = [e[0] for e in data]
qst = [e[1] for e in data]
@@ -97,6 +91,7 @@ def train(epoch, ternary, rel, norel):
random.shuffle(rel)
random.shuffle(norel)
+
ternary = cvt_data_axis(ternary)
rel = cvt_data_axis(rel)
norel = cvt_data_axis(norel)
@@ -224,6 +219,7 @@ def load_data():
filename = os.path.join(dirs,'sort-of-clevr.pickle')
with open(filename, 'rb') as f:
train_datasets, test_datasets = pickle.load(f)
+
ternary_train = []
ternary_test = []
rel_train = []
@@ -255,6 +251,7 @@ def load_data():
ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test = load_data()
+
try:
os.makedirs(model_dirs)
except:
@@ -276,11 +273,7 @@ def load_data():
print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...")
for epoch in range(1, args.epochs + 1):
- train_acc_ternary, train_acc_binary, train_acc_unary = train(
- epoch, ternary_train, rel_train, norel_train)
- test_acc_ternary, test_acc_binary, test_acc_unary = test(
- epoch, ternary_test, rel_test, norel_test)
-
- csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary,
- train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
- model.save_model(epoch)
+ train_acc_ternary, train_acc_binary, train_acc_unary = train(epoch, ternary_train, rel_train, norel_train)
+ test_acc_ternary, test_acc_binary, test_acc_unary = test(epoch, ternary_test, rel_test, norel_test)
+ csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary, train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
+ model.save_model(epoch)
\ No newline at end of file
diff --git a/model.py b/model.py
index f2bc9d2..1321f36 100644
--- a/model.py
+++ b/model.py
@@ -49,7 +49,8 @@ def forward(self, x):
x = F.relu(x)
x = F.dropout(x)
x = self.fc3(x)
- return F.log_softmax(x, dim=1)
+
+ return x
class BasicModel(nn.Module):
def __init__(self, args, name):
@@ -59,9 +60,10 @@ def __init__(self, args, name):
def train_(self, input_img, input_qst, label):
self.optimizer.zero_grad()
output = self(input_img, input_qst)
- loss = F.nll_loss(output, label)
+ loss = F.cross_entropy(output, label)
loss.backward()
self.optimizer.step()
+ output = F.log_softmax(output, dim=1)
pred = output.data.max(1)[1]
correct = pred.eq(label.data).cpu().sum()
accuracy = correct * 100. / len(label)
@@ -69,7 +71,8 @@ def train_(self, input_img, input_qst, label):
def test_(self, input_img, input_qst, label):
output = self(input_img, input_qst)
- loss = F.nll_loss(output, label)
+ loss = F.cross_entropy(output, label)
+ output = F.log_softmax(output, dim=1)
pred = output.data.max(1)[1]
correct = pred.eq(label.data).cpu().sum()
accuracy = correct * 100. / len(label)
@@ -100,11 +103,18 @@ def __init__(self, args):
self.f_fc1 = nn.Linear(256, 256)
+ self.fcout = FCOutputModel()
+
+ self.optimizer = optim.Adam(self.parameters(), lr=args.lr)
+
+
self.coord_oi = torch.FloatTensor(args.batch_size, 2)
self.coord_oj = torch.FloatTensor(args.batch_size, 2)
+
if args.cuda:
self.coord_oi = self.coord_oi.cuda()
self.coord_oj = self.coord_oj.cuda()
+
self.coord_oi = Variable(self.coord_oi)
self.coord_oj = Variable(self.coord_oj)
@@ -113,18 +123,19 @@ def cvt_coord(i):
return [(i/5-2)/2., (i%5-2)/2.]
self.coord_tensor = torch.FloatTensor(args.batch_size, 25, 2)
+
if args.cuda:
self.coord_tensor = self.coord_tensor.cuda()
+
self.coord_tensor = Variable(self.coord_tensor)
+
np_coord_tensor = np.zeros((args.batch_size, 25, 2))
+
for i in range(25):
- np_coord_tensor[:,i,:] = np.array( cvt_coord(i) )
- self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor))
+ np_coord_tensor[:, i, :] = np.array(cvt_coord(i))
+ self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor))
- self.fcout = FCOutputModel()
-
- self.optimizer = optim.Adam(self.parameters(), lr=args.lr)
def forward(self, img, qst):
@@ -134,13 +145,15 @@ def forward(self, img, qst):
mb = x.size()[0]
n_channels = x.size()[1]
d = x.size()[2]
+
+ # view -> (64 x 24 x 25)
+ # permute -> (64 x 25 x 24)
# x_flat = (64 x 25 x 24)
- x_flat = x.view(mb,n_channels,d*d).permute(0,2,1)
-
+ x_flat = x.view(mb, n_channels, d*d).permute(0,2,1)
# add coordinates
- x_flat = torch.cat([x_flat, self.coord_tensor],2)
+ x_flat = torch.cat([x_flat, self.coord_tensor], 2)
if self.relation_type == 'ternary':
# add question everywhere
qst = torch.unsqueeze(qst, 1) # (64x1x18)
@@ -166,12 +179,12 @@ def forward(self, img, qst):
x_full = torch.cat([x_i, x_j, x_k], 4) # (64x25x25x25x3*26+18)
# reshape for passing through network
- x_ = x_full.view(mb * (d * d) * (d * d) * (d * d), 96) # (64*25*25*25x3*26+18) = (1.000.000, 96)
+ x_ = x_full.view(mb * (d * d) * (d * d) * (d * d), 96) # (64*25*25*25x3*26+18) = (1,000,000, 96)
else:
# add question everywhere
- qst = torch.unsqueeze(qst, 1)
- qst = qst.repeat(1, 25, 1)
- qst = torch.unsqueeze(qst, 2)
+ qst = torch.unsqueeze(qst, 1) # (64x1x18)
+ qst = qst.repeat(1, 25, 1) # (64x25x18)
+ qst = torch.unsqueeze(qst, 2) # (64x25x1x18)
# cast all pairs against each other
x_i = torch.unsqueeze(x_flat, 1) # (64x1x25x26+18)
@@ -184,7 +197,7 @@ def forward(self, img, qst):
x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+18)
# reshape for passing through network
- x_ = x_full.view(mb * (d * d) * (d * d), 70) # (64*25*25x2*26*18) = (40.000, 70)
+ x_ = x_full.view(mb * (d * d) * (d * d), 70) # (64*25*25x2*26+18) = (40,000, 70)
x_ = self.g_fc1(x_)
x_ = F.relu(x_)
@@ -201,6 +214,7 @@ def forward(self, img, qst):
else:
x_g = x_.view(mb, (d * d) * (d * d), 256)
+
x_g = x_g.sum(1).squeeze()
"""f"""
@@ -219,7 +233,6 @@ def __init__(self, args):
self.fcout = FCOutputModel()
self.optimizer = optim.Adam(self.parameters(), lr=args.lr)
- #print([ a for a in self.parameters() ] )
def forward(self, img, qst):
x = self.conv(img) ## x = (64 x 24 x 5 x 5)
@@ -232,5 +245,4 @@ def forward(self, img, qst):
x_ = self.fc1(x_)
x_ = F.relu(x_)
- return self.fcout(x_)
-
+ return self.fcout(x_)
\ No newline at end of file
diff --git a/readme_img/binary_relational_acc.png b/readme_img/binary_relational_acc.png
new file mode 100644
index 0000000..1a6649f
Binary files /dev/null and b/readme_img/binary_relational_acc.png differ
diff --git a/readme_img/non_relational_acc.png b/readme_img/non_relational_acc.png
new file mode 100644
index 0000000..3c541a4
Binary files /dev/null and b/readme_img/non_relational_acc.png differ
diff --git a/readme_img/relational-network-application.gif b/readme_img/relational-network-application.gif
new file mode 100644
index 0000000..fec6351
Binary files /dev/null and b/readme_img/relational-network-application.gif differ
diff --git a/readme_img/ternary_relational_acc.png b/readme_img/ternary_relational_acc.png
new file mode 100644
index 0000000..272ea1a
Binary files /dev/null and b/readme_img/ternary_relational_acc.png differ
diff --git a/sort_of_clevr_generator.py b/sort_of_clevr_generator.py
index f3e3a12..b9b76db 100644
--- a/sort_of_clevr_generator.py
+++ b/sort_of_clevr_generator.py
@@ -2,21 +2,21 @@
import os
import numpy as np
import random
-#import cPickle as pickle
+
import pickle
import warnings
import argparse
+from translator import translate
parser = argparse.ArgumentParser(description='Sort-of-CLEVR dataset generator')
-parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
-parser.add_argument('--t-subtype', type=int, default=-1,
- help='Force ternary questions to be of a given type')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--t-subtype', type=int, default=-1, help='Force ternary questions to be of a given type')
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
+
train_size = 9800
test_size = 200
img_size = 75
@@ -24,7 +24,18 @@
question_size = 18 ## 2 x (6 for one-hot vector of color), 3 for question type, 3 for question subtype
q_type_idx = 12
sub_q_type_idx = 15
-"""Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]"""
+
+'''
+Question [
+ obj 1 color: 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5,
+ obj 2 color: 'red':6, 'green':7, 'blue':8, 'orange':9, 'gray':10, 'yellow':11,
+    question type: 'no_rel':12 | 'rel':13 | 'ternary':14,
+ 'sub-type[0]:15 query shape->rectangle/circle | closest-to->rectangle/circle | between->1~4',
+ 'sub-type[1]:16 query horizontal position->yes/no | furthest-from->rectangle/circle | is-on-band->yes/no',
+ 'sub-type[2]:17 query vertical position->yes/no | count->1~6 | count-obtuse-triangles->1~6'
+ ]
+Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]
+'''
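+# Worked example of the encoding above: the binary relational question
+# "what shape is closest to the red object?" sets question[0]=1 (red),
+# question[13]=1 (rel) and question[15]=1 (sub-type[0], closest-to);
+# all other entries stay 0.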
nb_questions = 10
dirs = './data'
@@ -109,6 +120,7 @@ def build_dataset():
answer = 0
else:
answer = 1
+
norel_answers.append(answer)
"""Binary Relational questions"""
@@ -125,7 +137,7 @@ def build_dataset():
"""closest-to->rectangle/circle"""
my_obj = objects[color][1]
dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects]
- dist_list[dist_list.index(0)] = 999
+    dist_list[dist_list.index(0)] = 99999  # exclude the object's own zero distance
closest = dist_list.index(min(dist_list))
if objects[closest][2] == 'r':
answer = 2
@@ -258,6 +270,11 @@ def build_dataset():
norelations = (norel_questions, norel_answers)
img = img/255.
+
+ # save all image
+ # img_count = num_total
+ # cv2.imwrite(os.path.join(dirs+'/all dataset','{}.png'.format(img_count)), cv2.resize(img*255, (img_size,img_size)))
+
dataset = (img, ternary_relations, binary_relations, norelations)
return dataset
@@ -270,7 +287,7 @@ def build_dataset():
#img_count = 0
#cv2.imwrite(os.path.join(dirs,'{}.png'.format(img_count)), cv2.resize(train_datasets[0][0]*255, (512,512)))
-
+# translate(train_datasets[0])
print('saving datasets...')
filename = os.path.join(dirs,'sort-of-clevr.pickle')
diff --git a/translator.py b/translator.py
index ac3d242..755e5de 100644
--- a/translator.py
+++ b/translator.py
@@ -1,36 +1,47 @@
import cv2
+
def translate(dataset):
+
img, (rel_questions, rel_answers), (norel_questions, norel_answers) = dataset
colors = ['red ', 'green ', 'blue ', 'orange ', 'gray ', 'yellow ']
answer_sheet = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6']
questions = rel_questions + norel_questions
answers = rel_answers + norel_answers
-
- print rel_questions
- print rel_answers
-
-
- for question,answer in zip(questions,answers):
+ '''
+ Question [
+ obj 1 color: 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5,
+ obj 2 color: 'red':6, 'green':7, 'blue':8, 'orange':9, 'gray':10, 'yellow':11,
+        question type: 'no_rel':12 | 'rel':13 | 'ternary':14,
+ 'sub-type[0]:15 query shape->rectangle/circle | closest-to->rectangle/circle | between->1~4',
+ 'sub-type[1]:16 query horizontal position->yes/no | furthest-from->rectangle/circle | is-on-band->yes/no',
+ 'sub-type[2]:17 query vertical position->yes/no | count->1~6 | count-obtuse-triangles->1~6'
+ ] '''
+
+ ''' 3 for question type(index: 12:14), 3 for question subtype(index: 15:17) '''
+ for question, answer in zip(questions,answers):
+
query = ''
query += colors[question.tolist()[0:6].index(1)]
-
- if question[6] == 1:
- if question[8] == 1:
+
+ if question[12] == 1:
+ if question[15] == 1:
query += 'shape?'
- if question[9] == 1:
+ if question[16] == 1:
query += 'left?'
- if question[10] == 1:
+ if question[17] == 1:
query += 'up?'
- if question[7] == 1:
- if question[8] == 1:
+
+ if question[13] == 1:
+ if question[15] == 1:
query += 'closest shape?'
- if question[9] == 1:
+ if question[16] == 1:
query += 'furthest shape?'
- if question[10] == 1:
+ if question[17] == 1:
query += 'count?'
ans = answer_sheet[answer]
- print query,'==>', ans
+ print(query,'==>', ans)
+
#cv2.imwrite('sample.jpg',(img*255).astype(np.int32))
cv2.imshow('img',cv2.resize(img,(512,512)))
cv2.waitKey(0)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..0605b28
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,2 @@
+from .project_ui import *
+from .preditor_util import *
\ No newline at end of file
diff --git a/utils/preditor_util.py b/utils/preditor_util.py
new file mode 100644
index 0000000..58a37d8
--- /dev/null
+++ b/utils/preditor_util.py
@@ -0,0 +1,115 @@
+from model import RN
+import cv2
+import string
+import numpy as np
+import torch
+import torch.nn.functional as F
+import argparse
+
+parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
+parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN', help='model architecture (RN or CNN_MLP)')
+parser.add_argument('--batch-size', type=int, default=1, metavar='N', help='input batch size (default: 1 for single-image prediction)')
+parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
+parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
+parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
+parser.add_argument('--resume', type=str, help='resume from model stored')
+parser.add_argument('--relation-type', type=str, default='binary', help='what kind of relations to learn. options: binary, ternary (default: binary)')
+
+args = parser.parse_args()
+# the demo runs the RN on CPU
+args.cuda = False
+
+class RNPredictor():
+    def __init__(self):
+        # load the trained RN once so predict() does not reload the checkpoint
+        # on every call; map_location lets the CUDA-trained weights load on CPU
+        self.model = RN(args)
+        checkpoint = torch.load('best_model/epoch_RN_20.pth', map_location='cpu')
+        self.model.load_state_dict(checkpoint)
+        self.model.eval()
+
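+    # convert an (H, W, C) numpy image and an 18-dim question list into the
+    # (1, C, H, W) and (1, 18) float tensors the RN forward pass expects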
+ def __tensor_data(self, data):
+ image, question = data
+ image = torch.FloatTensor(image).unsqueeze(0)
+ image = image.permute(0, 3, 1, 2)
+ question = torch.FloatTensor([question])
+ return image, question
+
+    def __map_question_type(self, idx):
+        # combobox index -> (question-type index, sub-type index) in the 18-dim
+        # question vector: types 1-3 are non-relational (12), 4-6 relational (13)
+        type_map = { 1: (12, 15), 2: (12, 16), 3: (12, 17),
+                     4: (13, 15), 5: (13, 16), 6: (13, 17) }
+        return type_map[idx]
+
+ def __convert_to_vector(self, question_idx, tokens):
+ color_map = { 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5 }
+ question = [ 0 for i in range(18) ]
+ for word in tokens:
+ if word in color_map.keys():
+ question[color_map[word]] = 1
+        question_idx, question_sub_type = self.__map_question_type(question_idx)
+ question[question_idx] = 1
+ question[question_sub_type] = 1
+ return question
+
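+    # tokenize() turns a numbered question string into the 18-dim one-hot vector
+    # the RN expects, e.g. "1.What is the shape of the red object ?" yields a
+    # vector with 1s at index 0 (red), 12 (non-relational) and 15 (query shape).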
+    def tokenize(self, sentence):
+        type_split = sentence.split('.')
+        question_idx = int(type_split[0])
+        sentence = type_split[1]
+        # strip all punctuation before splitting into tokens
+        no_punctuation_sentence = sentence
+        for c in string.punctuation:
+            no_punctuation_sentence = no_punctuation_sentence.replace(c, '')
+        tokens = no_punctuation_sentence.split()
+        question = self.__convert_to_vector(question_idx, tokens)
+        return question
+
+    def predict(self, data):
+        classes = [ 'yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6' ]
+        input_image, input_question = self.__tensor_data(data)
+        with torch.no_grad():
+            output = self.model(input_image, input_question)
+        output = F.log_softmax(output, dim=0)
+        _, prediction = output.data.max(0)
+        return classes[prediction]
+
+'''
+# testing
+if __name__ == '__main__':
+ image = cv2.imread('image_RGB.jpg')
+ colors = [ 'red', 'green', 'blue', 'orange', 'gray', 'yellow' ]
+
+ questions = [
+ "1.What is the shape of the object ?",
+ "2.Is object placed on the left side of the image ?",
+ "3.Is object placed on the upside of the image ?",
+ "4.What is the shape of the object closest to the object ?",
+ "5.What is the shape of the object furthest to the object ?",
+ "6.How many objects have same shape with the object ?" ]
+
+ rn_predictor = RNPredictor()
+    for question in questions:
+        for color in colors:
+            # insert the color before the first "object", e.g. "... the red object ?"
+            sentence = question.replace('object', color + ' object', 1)
+ print('Question:', sentence)
+ question = rn_predictor.tokenize(sentence)
+ result = rn_predictor.predict((image/255, question))
+ print('Answer:', result)
+'''
\ No newline at end of file
diff --git a/utils/project_ui.py b/utils/project_ui.py
new file mode 100644
index 0000000..4934ce9
--- /dev/null
+++ b/utils/project_ui.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'project_UI.ui'
+#
+# Created by: PyQt5 UI code generator 5.9.2
+#
+# WARNING! All changes made in this file will be lost!
+from PyQt5 import QtCore, QtWidgets,QtGui
+from PyQt5.QtGui import QFont as qfont
+from .shape_util import ShapeCanvas
+
+class Ui_MainWindow(object):
+ def setupUi(self, MainWindow):
+ MainWindow.setObjectName("MainWindow")
+ MainWindow.resize(740, 804)
+ self.setStyleSheet("background-color: rgb(224, 227, 210);")
+ self.centralwidget = QtWidgets.QWidget(MainWindow)
+ self.centralwidget.setObjectName("centralwidget")
+ self.pushButton = QtWidgets.QPushButton(self.centralwidget)
+ self.pushButton.setGeometry(QtCore.QRect(630, 650, 81, 31))
+ self.pushButton.setStyleSheet("color: rgb(5,22,39)")
+ self.pushButton.setFont(qfont('Times', 10))
+ self.pushButton.setObjectName("pushButton")
+ self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
+ self.pushButton_2.setGeometry(QtCore.QRect(630, 720, 81, 31))
+ self.pushButton_2.setStyleSheet("color: rgb(5,22,39)")
+ self.pushButton_2.setFont(qfont('Times', 10))
+ self.pushButton_2.setObjectName("pushButton_2")
+
+ self.label = ShapeCanvas(self)
+ self.label.setGeometry(QtCore.QRect(70, 20, 600, 600))
+ self.label.setAutoFillBackground(True)
+ self.label.setText("")
+ self.label.setObjectName("label")
+
+ self.label2 = QtWidgets.QLabel(self.centralwidget)
+ self.label2.setGeometry(QtCore.QRect(10, 630, 80, 15))
+ self.label2.setAutoFillBackground(True)
+ self.label2.setStyleSheet("color: rgb(5,22,39)")
+ self.label2.setFont(qfont('Times', 10 ,QtGui.QFont.Bold))
+ self.label2.setText("Question:")
+ self.label2.setObjectName("label2")
+
+ self.label3 = QtWidgets.QLabel(self.centralwidget)
+ self.label3.setGeometry(QtCore.QRect(10, 720, 70, 31))
+ self.label3.setAutoFillBackground(True)
+ self.label3.setStyleSheet("color: rgb(5,22,39)")
+ self.label3.setFont(qfont('Times', 10 ,QtGui.QFont.Bold))
+ self.label3.setText("Answer:")
+ self.label3.setObjectName("label3")
+
+ self.comboBox = QtWidgets.QComboBox(self.centralwidget)
+ self.comboBox.setEnabled(True)
+ self.comboBox.setGeometry(QtCore.QRect(30, 650, 571, 60))
+ self.comboBox.setStyleSheet("color: rgb(5,22,39)")
+ self.comboBox.setFont(qfont('Times', 12))
+ self.comboBox.setEditable(True)
+ self.comboBox.lineEdit().setFont(qfont('Times', 12, QtGui.QFont.Bold))
+ self.comboBox.setObjectName("comboBox")
+ self.comboBox.addItem("1.What is the shape of the object ?")
+ self.comboBox.addItem("2.Is object placed on the left side of the image ?")
+ self.comboBox.addItem("3.Is object placed on the up side of the image ?")
+ self.comboBox.addItem("4.What is the shape of the object closest to the object ?")
+ self.comboBox.addItem("5.What is the shape of the object furthest to the object ?")
+ self.comboBox.addItem("6.How many objects have same shape with the object ?")
+ self.label4 = QtWidgets.QLabel(self.centralwidget)
+ self.label4.setStyleSheet("color: rgb(5,22,39)")
+ self.label4.setFont(qfont('Times', 12))
+ self.label4.setGeometry(QtCore.QRect(100, 720, 500, 31))
+ self.label4.setObjectName("label4")
+ self.label4.setText("")
+ MainWindow.setCentralWidget(self.centralwidget)
+ self.menubar = QtWidgets.QMenuBar(MainWindow)
+ self.menubar.setGeometry(QtCore.QRect(0, 0, 852, 26))
+ self.menubar.setObjectName("menubar")
+ MainWindow.setMenuBar(self.menubar)
+ self.statusbar = QtWidgets.QStatusBar(MainWindow)
+ self.statusbar.setObjectName("statusbar")
+ MainWindow.setStatusBar(self.statusbar)
+
+ self.retranslateUi(MainWindow)
+ QtCore.QMetaObject.connectSlotsByName(MainWindow)
+
+ def retranslateUi(self, MainWindow):
+ _translate = QtCore.QCoreApplication.translate
+ MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
+ self.pushButton.setText(_translate("MainWindow", "OK"))
+ self.pushButton_2.setText(_translate("MainWindow", "Random"))
+
diff --git a/utils/shape_util.py b/utils/shape_util.py
new file mode 100644
index 0000000..a597525
--- /dev/null
+++ b/utils/shape_util.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Jan 13 18:03:34 2022
+
+@author: helen
+"""
+
+from PyQt5 import QtWidgets
+from PyQt5.QtCore import QRect, Qt
+from PyQt5.QtGui import QPainter, QPen, QColor, QPixmap
+import random
+
+class Shape():
+ def __init__(self):
+ self.object_shape = []
+ self.object_x = [ 17, 114, 211, 308, 405, 502 ]
+ self.object_y = [ 500, 500, 500, 500, 500, 500 ]
+ self.object_color = [ "#FF0000", # r
+ "#00FF00", # g
+ "#0000FF", # b
+ "#ff9d00", # o
+ "#808080", # k
+ "#FFFF00" ] # y
+
+ def random_shape(self):
+ random.shuffle(self.object_color)
+ self.object_shape.clear()
+ for i in range(6):
+ self.object_shape.append(random.choice(['rectangle', 'circle']))
+
+class ShapeCanvas(QtWidgets.QLabel):
+ def __init__(self, parent=None):
+ super().__init__(parent)
+
+ self.canvas = QPixmap(600, 600)
+ self.canvas.fill(QColor("white"))
+ self.setPixmap(self.canvas)
+
+ self.shape = Shape()
+ self.shape.random_shape()
+ self.shape_size = 80
+ self.drag_idx = -1
+
+ def paintEvent(self, event):
+ super(ShapeCanvas, self).paintEvent(event)
+ painter = QPainter(self)
+ painter.setRenderHint(QPainter.Antialiasing)
+
+ for i in range(6):
+ color = QColor(self.shape.object_color[i])
+ pen = QPen(color, 1)
+ painter.setBrush(color)
+ painter.setPen(pen)
+ if self.shape.object_shape[i] == 'rectangle':
+ painter.drawRect(QRect(self.shape.object_x[i],self.shape.object_y[i],self.shape_size,self.shape_size))
+ else:
+ painter.drawEllipse(QRect(self.shape.object_x[i],self.shape.object_y[i],self.shape_size,self.shape_size))
+
+ def mousePressEvent(self, event):
+ if event.button() == Qt.LeftButton:
+ self.drag_idx = self.drag_index(event.pos())
+ # print('drag_index:', self.drag_idx)
+
+ def drag_index(self, position):
+ for i in range(6):
+ if self.is_pointSelected(self.shape.object_x[i],self.shape.object_y[i], position):
+ return i
+ return -1
+
+ def is_pointSelected(self, point_x ,point_y, position):
+ # point range
+        x_min = point_x - self.shape_size
+ x_max = point_x + self.shape_size
+ y_min = point_y - self.shape_size
+ y_max = point_y + self.shape_size
+
+        # check whether the click position falls inside the hit box around the shape
+ x, y = position.x(), position.y()
+ if x_min < x < x_max and y_min < y < y_max:
+ return True
+ return False
+
+ def mouseMoveEvent(self, event):
+ if self.drag_idx != -1:
+ self.shape.object_x[self.drag_idx] = event.pos().x()
+ self.shape.object_y[self.drag_idx] = event.pos().y()
+ self.update()
+
+ def mouseReleaseEvent(self, event):
+ self.drag_idx = -1
+ self.update()
+
+
+ def pushButton2_clicked(self): # Random
+ self.shape.random_shape()
+ self.repaint()
\ No newline at end of file