diff --git a/README.md b/README.md
index 49717fe..542f614 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,23 @@
 CNN + MLP overfitted the training data, while the Relational Network shows far better results on both relational and non-relational questions.
 
+## Application Demo
+Retraining experiment results:
+
+![](./readme_img/binary_relational_acc.png)
+![](./readme_img/non_relational_acc.png)
+![](./readme_img/ternary_relational_acc.png)
+
+You can randomly generate and drag 2D shapes on the canvas, then edit the question text to ask the model about them:
+
+    $ python application.py
+
+![](./readme_img/relational-network-application.gif)
+
 ## Contributions
 [@gngdb](https://github.com/gngdb) speeds up the model by 10 times.
+
+[@neural022](https://github.com/neural022) and [@hhhlll21qq](https://github.com/hhhlll21qq) built the application.
diff --git a/RN_1_log.csv b/RN_1_log.csv
new file mode 100644
index 0000000..13da807
--- /dev/null
+++ b/RN_1_log.csv
@@ -0,0 +1,21 @@
+epoch,train_acc_ternary,train_acc_rel,train_acc_norel,test_acc_ternary,test_acc_rel,test_acc_norel
+1,51.555355976485956,50.55825440888308,52.975996080992815,54.13306451612903,61.29032258064516,55.69556451612903
+2,53.33830013063357,66.2342831482691,55.337606139777925,52.368951612903224,72.88306451612904,56.703629032258064
+3,53.53833278902678,69.31539843239713,56.32654310907903,51.965725806451616,72.22782258064517,58.71975806451613
+4,53.70672762900065,69.78996570868713,58.261552906597,52.167338709677416,73.38709677419355,61.44153225806452
+5,53.97615937295885,70.14104343566297,60.37924559111692,52.318548387096776,72.53024193548387,63.457661290322584
+6,54.60687459177009,70.68807152188113,61.68456074461137,53.024193548387096,71.52217741935483,64.21370967741936
+7,55.12634715871979,70.62377531025473,62.19382756368387,52.671370967741936,72.58064516129032,63.608870967741936
+8,55.39782005225343,71.11058948399739,62.6500244937949,53.931451612903224,74.24395161290323,63.76008064516129
+9,55.49477465708687,71.87704114957545,63.30013063357283,54.48588709677419,74.04233870967742,66.22983870967742
+10,55.601935009797515,72.61491672109732,66.2832707380797,53.42741935483871,74.49596774193549,71.27016129032258
+11,55.741753755715216,73.1078543435663,77.11258981058133,53.78024193548387,74.44556451612904,86.18951612903226
+12,55.741753755715216,73.47015839320706,93.81633736120183,54.435483870967744,76.00806451612904,98.03427419354838
+13,55.940765839320704,76.68394839973874,98.29564010450686,54.939516129032256,80.54435483870968,98.5383064516129
+14,56.174477465708684,82.49101894186806,99.06107119529719,53.881048387096776,84.77822580645162,97.88306451612904
+15,56.41125081645983,84.75771554539517,99.32744121489223,55.39314516129032,85.38306451612904,98.7399193548387
+16,56.56841933376878,86.46513716525146,99.5009389288047,55.292338709677416,85.6350806451613,98.94153225806451
+17,56.68068256041803,87.47448563030699,99.63463422599608,56.30040322580645,87.3991935483871,98.99193548387096
+18,57.09911822338341,88.2388961463096,99.67851894186806,55.645161290322584,87.3991935483871,99.14314516129032
+19,57.25730731548008,88.99208033964729,99.7305682560418,56.149193548387096,88.20564516129032,99.19354838709677
+20,57.507348138471585,89.61258981058133,99.7601649248857,56.09879032258065,88.10483870967742,98.89112903225806
diff --git a/application.py b/application.py
new file mode 100644
index 0000000..bff7e54
--- /dev/null
+++ b/application.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Jan 18 10:29:41 2022
+
+@author: helen, george chen(neural022)
+"""
+
+import sys
+import cv2
+from PyQt5 import QtWidgets
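+# NOTE (assumption, inferred from RNPredictor.tokenize): questions typed into
+# the combo box must look like "<n>.<question text>", where <n> in 1..6 selects
+# the question type/subtype and the text must contain one color word,
+# e.g. "1.What is the shape of the red object ?"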
+from utils import Ui_MainWindow, RNPredictor
+from PIL import ImageQt, Image
+from numpy import asarray
+import time
+
+
+class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
+    def __init__(self, parent=None):
+        super(MainWindow, self).__init__(parent=parent)
+        self.setupUi(self)
+        self.rn_predictor = RNPredictor()
+
+        self.pushButton.clicked.connect(self.pushButton_clicked)
+        self.pushButton_2.clicked.connect(self.label.pushButton2_clicked)
+
+    def pushButton_clicked(self):  # "OK": snapshot the canvas and run the model
+        # grab the canvas as RGBA, flatten onto white, and shrink to the
+        # 75x75 input size the RN was trained on
+        label_image = ImageQt.fromqpixmap(self.label.grab())
+        image_RGB = Image.new("RGB", label_image.size, (255, 255, 255))
+        image_RGB.paste(label_image, mask=label_image.split()[3])
+        image_RGB = image_RGB.resize((75, 75), Image.ANTIALIAS)
+        image_RGB.save('image_RGB.jpg', 'JPEG', quality=100)
+        img_rgb_array = asarray(image_RGB)
+        # the model expects OpenCV-style BGR channel order, shape (75, 75, 3)
+        img_bgr_array = cv2.cvtColor(img_rgb_array, cv2.COLOR_RGB2BGR)
+        time.sleep(0.5)
+        # question comes from the editable comboBox; answer goes to label4
+        if self.comboBox.currentText() != '':
+            self.question = self.comboBox.currentText()
+            print('Question:', self.question)
+            question = self.rn_predictor.tokenize(self.question)
+            self.answer = self.rn_predictor.predict((img_bgr_array/255, question))
+            print('Answer:', self.answer)
+            self.label4.setText(self.answer)
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication([])
+    app.setStyle('Fusion')
+    w = MainWindow()
+    w.show()
+
+    sys.exit(app.exec())
\ No newline at end of file
diff --git a/best_model/epoch_RN_20.pth b/best_model/epoch_RN_20.pth
new file mode 100644
index 0000000..90117d8
Binary files /dev/null and b/best_model/epoch_RN_20.pth differ
diff --git a/environment.yml b/environment.yml
index b3ac08c..9d708ea 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,28 +1,103 @@
 name: RN3
 channels:
+  - pytorch
   - defaults
 dependencies:
-  - ca-certificates
-  - certifi
-  - libcxx
-  - libcxxabi
-  - libedit
-  - libffi
-  - ncurses
-  - openssl
-  - pip=20.0.2=py38_1
-  - python=3.8.1
-  - readline
-  - setuptools
-  - sqlite=3.31.1
-  - tk
-  - wheel
-  - xz
-  - zlib
+  - blas=1.0=mkl
+  - ca-certificates=2021.10.26=haa95532_2
+  - certifi=2021.10.8=py38haa95532_0
+  - cudatoolkit=11.3.1=h59b6b97_2
+  - freetype=2.10.4=hd328e21_0
+  - intel-openmp=2021.4.0=haa95532_3556
+  - jpeg=9d=h2bbff1b_0
+  - libpng=1.6.37=h2a8f88b_0
+  - libtiff=4.2.0=hd0e1b90_0
+  - libuv=1.40.0=he774522_0
+  - libwebp=1.2.0=h2bbff1b_0
+  - lz4-c=1.9.3=h2bbff1b_1
+  - mkl=2021.4.0=haa95532_640
+  - mkl-service=2.4.0=py38h2bbff1b_0
+  - mkl_fft=1.3.1=py38h277e83a_0
+  - mkl_random=1.2.2=py38hf11a4ad_0
+  - numpy-base=1.21.2=py38h0829f74_0
+  - olefile=0.46=pyhd3eb1b0_0
+  - openssl=1.1.1l=h2bbff1b_0
+  - pillow=8.4.0=py38hd45dc43_0
+  - pip=21.2.2=py38haa95532_0
+  - python=3.8.1=h5fd99cc_8_cpython
+  - pytorch=1.10.1=py3.8_cuda11.3_cudnn8_0
+  - pytorch-mutex=1.0=cuda
+  - setuptools=58.0.4=py38haa95532_0
+  - six=1.16.0=pyhd3eb1b0_0
+  - sqlite=3.31.1=h2a8f88b_1
+  - tk=8.6.11=h2bbff1b_0
+  - torchaudio=0.10.1=py38_cu113
+  - torchvision=0.11.2=py38_cu113
+  - typing_extensions=3.10.0.2=pyh06a4308_0
+  - vc=14.2=h21ff451_1
+  - vs2015_runtime=14.27.29016=h5e58377_2
+  - wheel=0.37.0=pyhd3eb1b0_1
+  - wincertstore=0.2=py38haa95532_2
+  - xz=5.2.5=h62dcd97_0
+  - zlib=1.2.11=h8cc25b3_4
+  - zstd=1.4.9=h19a0ad4_0
   - pip:
+    - absl-py==1.0.0
+    - altgraph==0.17.2
+    - cachetools==4.2.4
+    - charset-normalizer==2.0.9
+    - click==7.1.2
+    - colorama==0.4.4
+    - filelock==3.4.2
+    - future==0.18.2
+    - google-auth==1.35.0
+    - google-auth-oauthlib==0.4.6
+    - grpcio==1.43.0
+    - huggingface-hub==0.4.0
+    
- idna==3.3
+    - importlib-metadata==4.10.0
+    - joblib==1.1.0
+    - markdown==3.3.6
     - numpy==1.18.1
+    - oauthlib==3.1.1
     - opencv-python==4.2.0.32
-    - torch==1.4.0
+    - packaging==21.3
+    - pandas==1.3.5
+    - pefile==2021.9.3
+    - protobuf==3.19.1
+    - pyasn1==0.4.8
+    - pyasn1-modules==0.2.8
+    - pyinstaller==4.8
+    - pyinstaller-hooks-contrib==2021.5
+    - pyparsing==3.0.6
+    - pyqt5==5.15.6
+    - pyqt5-plugins==5.15.4.2.2
+    - pyqt5-qt5==5.15.2
+    - pyqt5-sip==12.9.0
+    - python-dateutil==2.8.2
+    - python-dotenv==0.19.2
+    - pytz==2021.3
+    - pywin32-ctypes==0.2.0
+    - pyyaml==6.0
+    - qt5-applications==5.15.2.2.2
+    - qt5-tools==5.15.2.1.2
+    - regex==2021.11.10
+    - requests==2.27.0
+    - requests-oauthlib==1.3.0
+    - rsa==4.8
+    - sacremoses==0.0.47
+    - scikit-learn==1.0.2
+    - scipy==1.7.3
+    - sklearn==0.0
     - tensorboard==2.2.0
     - tensorboard-plugin-wit==1.6.0.post2
-prefix: ~/miniconda3/envs/RN3
+    - threadpoolctl==3.0.0
+    - tokenizers==0.10.3
+    - torch-tb-profiler==0.3.1
+    - tqdm==4.62.3
+    - transformers==4.15.0
+    - typing-extensions==4.0.1
+    - urllib3==1.26.7
+    - werkzeug==2.0.2
+    - zipp==3.7.0
+prefix: D:\Coding\Anaconda3\envs\RN3
diff --git a/image_RGB.jpg b/image_RGB.jpg
new file mode 100644
index 0000000..401f471
Binary files /dev/null and b/image_RGB.jpg differ
diff --git a/main.py b/main.py
index b3535bd..0c98aed 100644
--- a/main.py
+++ b/main.py
@@ -20,32 +20,24 @@
 # Training settings
 parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
-parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN',
-                    help='resume from model stored')
-parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                    help='input batch size for training (default: 64)')
-parser.add_argument('--epochs', type=int, default=20, metavar='N',
-                    help='number of epochs to train (default: 20)')
-parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
-                    help='learning rate (default: 0.0001)')
-parser.add_argument('--no-cuda', action='store_true', default=False,
-                    help='disables CUDA training')
-parser.add_argument('--seed', type=int, default=1, metavar='S',
-                    help='random seed (default: 1)')
-parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                    help='how many batches to wait before logging training status')
-parser.add_argument('--resume', type=str,
-                    help='resume from model stored')
-parser.add_argument('--relation-type', type=str, default='binary',
-                    help='what kind of relations to learn. options: binary, ternary (default: binary)')
+parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN', help='which model to train (default: RN)')
+parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
+parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
+parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
+parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
+parser.add_argument('--resume', type=str, help='resume from model stored')
+parser.add_argument('--relation-type', type=str, default='binary', help='what kind of relations to learn. 
options: binary, ternary (default: binary)')
 
 args = parser.parse_args()
-args.cuda = not args.no_cuda and torch.cuda.is_available()
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+# only use CUDA when it is actually available and not disabled via --no-cuda
+args.cuda = device.type == 'cuda' and not args.no_cuda
 
 torch.manual_seed(args.seed)
 if args.cuda:
     torch.cuda.manual_seed(args.seed)
 
 summary_writer = SummaryWriter()
 
 if args.model=='CNN_MLP':
@@ -59,6 +51,7 @@
 input_qst = torch.FloatTensor(bs, 18)
 label = torch.LongTensor(bs)
 
 if args.cuda:
     model.cuda()
     input_img = input_img.cuda()
@@ -70,6 +63,8 @@
 label = Variable(label)
 
 def tensor_data(data, i):
+
+    # slice out mini-batch i: data[...][bs*i : bs*(i+1)]
     img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)]))
     qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)]))
     ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)]))
@@ -78,7 +73,6 @@ def tensor_data(data, i):
     input_qst.data.resize_(qst.size()).copy_(qst)
     label.data.resize_(ans.size()).copy_(ans)
 
-
 def cvt_data_axis(data):
     img = [e[0] for e in data]
     qst = [e[1] for e in data]
@@ -97,6 +91,7 @@ def train(epoch, ternary, rel, norel):
     random.shuffle(rel)
     random.shuffle(norel)
 
+
     ternary = cvt_data_axis(ternary)
     rel = cvt_data_axis(rel)
     norel = cvt_data_axis(norel)
@@ -224,6 +219,7 @@ def load_data():
     filename = os.path.join(dirs,'sort-of-clevr.pickle')
     with open(filename, 'rb') as f:
       train_datasets, test_datasets = pickle.load(f)
+
     ternary_train = []
     ternary_test = []
     rel_train = []
@@ -255,6 +251,7 @@
 ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test = load_data()
 
+
 try:
     os.makedirs(model_dirs)
 except:
@@ -276,11 +273,7 @@
 print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...")
 
 for epoch in range(1, args.epochs + 1):
-    train_acc_ternary, train_acc_binary, train_acc_unary = train(
-        epoch, ternary_train, rel_train, norel_train)
-    test_acc_ternary, test_acc_binary, test_acc_unary = test(
-        epoch, ternary_test, rel_test, norel_test)
-
-    csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary,
-                        train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
-    model.save_model(epoch)
+    train_acc_ternary, train_acc_binary, train_acc_unary = train(epoch, ternary_train, rel_train, norel_train)
+    test_acc_ternary, test_acc_binary, test_acc_unary = test(epoch, ternary_test, rel_test, norel_test)
+    csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary, train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
+    model.save_model(epoch)
\ No newline at end of file
diff --git a/model.py b/model.py
index f2bc9d2..1321f36 100644
--- a/model.py
+++ b/model.py
@@ -49,7 +49,8 @@ def forward(self, x):
         x = F.relu(x)
         x = F.dropout(x)
         x = self.fc3(x)
-        return F.log_softmax(x, dim=1)
+
+        return x  # raw logits; the loss below applies log-softmax internally
 
 class BasicModel(nn.Module):
     def __init__(self, args, name):
@@ -59,9 +60,10 @@ def __init__(self, args, name):
     def train_(self, input_img, input_qst, label):
         self.optimizer.zero_grad()
         output = self(input_img, input_qst)
-        loss = F.nll_loss(output, label)
+        loss = F.cross_entropy(output, label)  # equivalent to log_softmax + nll_loss
         loss.backward()
         self.optimizer.step()
+        output = F.log_softmax(output, dim=1)  # monotonic, so the argmax below is unchanged
         pred = output.data.max(1)[1]
         correct = pred.eq(label.data).cpu().sum()
         accuracy = correct * 100. 
/ len(label) @@ -69,7 +71,8 @@ def train_(self, input_img, input_qst, label): def test_(self, input_img, input_qst, label): output = self(input_img, input_qst) - loss = F.nll_loss(output, label) + loss = F.cross_entropy(output, label) + output = F.log_softmax(output, dim=1) pred = output.data.max(1)[1] correct = pred.eq(label.data).cpu().sum() accuracy = correct * 100. / len(label) @@ -100,11 +103,18 @@ def __init__(self, args): self.f_fc1 = nn.Linear(256, 256) + self.fcout = FCOutputModel() + + self.optimizer = optim.Adam(self.parameters(), lr=args.lr) + + self.coord_oi = torch.FloatTensor(args.batch_size, 2) self.coord_oj = torch.FloatTensor(args.batch_size, 2) + if args.cuda: self.coord_oi = self.coord_oi.cuda() self.coord_oj = self.coord_oj.cuda() + self.coord_oi = Variable(self.coord_oi) self.coord_oj = Variable(self.coord_oj) @@ -113,18 +123,19 @@ def cvt_coord(i): return [(i/5-2)/2., (i%5-2)/2.] self.coord_tensor = torch.FloatTensor(args.batch_size, 25, 2) + if args.cuda: self.coord_tensor = self.coord_tensor.cuda() + self.coord_tensor = Variable(self.coord_tensor) + np_coord_tensor = np.zeros((args.batch_size, 25, 2)) + for i in range(25): - np_coord_tensor[:,i,:] = np.array( cvt_coord(i) ) - self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor)) + np_coord_tensor[:, i, :] = np.array(cvt_coord(i)) + self.coord_tensor.data.copy_(torch.from_numpy(np_coord_tensor)) - self.fcout = FCOutputModel() - - self.optimizer = optim.Adam(self.parameters(), lr=args.lr) def forward(self, img, qst): @@ -134,13 +145,15 @@ def forward(self, img, qst): mb = x.size()[0] n_channels = x.size()[1] d = x.size()[2] + + # view -> (64 x 24 x 25) + # permute -> (64 x 25 x 24) # x_flat = (64 x 25 x 24) - x_flat = x.view(mb,n_channels,d*d).permute(0,2,1) - + x_flat = x.view(mb, n_channels, d*d).permute(0,2,1) # add coordinates - x_flat = torch.cat([x_flat, self.coord_tensor],2) + x_flat = torch.cat([x_flat, self.coord_tensor], 2) + - if self.relation_type == 'ternary': # add question everywhere qst = torch.unsqueeze(qst, 1) # (64x1x18) @@ -166,12 +179,12 @@ def forward(self, img, qst): x_full = torch.cat([x_i, x_j, x_k], 4) # (64x25x25x25x3*26+18) # reshape for passing through network - x_ = x_full.view(mb * (d * d) * (d * d) * (d * d), 96) # (64*25*25*25x3*26+18) = (1.000.000, 96) + x_ = x_full.view(mb * (d * d) * (d * d) * (d * d), 96) # (64*25*25*25x3*26+18) = (1,000,000, 96) else: # add question everywhere - qst = torch.unsqueeze(qst, 1) - qst = qst.repeat(1, 25, 1) - qst = torch.unsqueeze(qst, 2) + qst = torch.unsqueeze(qst, 1) # (64x1x18) + qst = qst.repeat(1, 25, 1) # (64x25x18) + qst = torch.unsqueeze(qst, 2) # (64x25x1x18) # cast all pairs against each other x_i = torch.unsqueeze(x_flat, 1) # (64x1x25x26+18) @@ -184,7 +197,7 @@ def forward(self, img, qst): x_full = torch.cat([x_i,x_j],3) # (64x25x25x2*26+18) # reshape for passing through network - x_ = x_full.view(mb * (d * d) * (d * d), 70) # (64*25*25x2*26*18) = (40.000, 70) + x_ = x_full.view(mb * (d * d) * (d * d), 70) # (64*25*25x2*26+18) = (40,000, 70) x_ = self.g_fc1(x_) x_ = F.relu(x_) @@ -201,6 +214,7 @@ def forward(self, img, qst): else: x_g = x_.view(mb, (d * d) * (d * d), 256) + x_g = x_g.sum(1).squeeze() """f""" @@ -219,7 +233,6 @@ def __init__(self, args): self.fcout = FCOutputModel() self.optimizer = optim.Adam(self.parameters(), lr=args.lr) - #print([ a for a in self.parameters() ] ) def forward(self, img, qst): x = self.conv(img) ## x = (64 x 24 x 5 x 5) @@ -232,5 +245,4 @@ def forward(self, img, qst): x_ = self.fc1(x_) x_ = 
F.relu(x_)
-        return self.fcout(x_)
-
+        return self.fcout(x_)
\ No newline at end of file
diff --git a/readme_img/binary_relational_acc.png b/readme_img/binary_relational_acc.png
new file mode 100644
index 0000000..1a6649f
Binary files /dev/null and b/readme_img/binary_relational_acc.png differ
diff --git a/readme_img/non_relational_acc.png b/readme_img/non_relational_acc.png
new file mode 100644
index 0000000..3c541a4
Binary files /dev/null and b/readme_img/non_relational_acc.png differ
diff --git a/readme_img/relational-network-application.gif b/readme_img/relational-network-application.gif
new file mode 100644
index 0000000..fec6351
Binary files /dev/null and b/readme_img/relational-network-application.gif differ
diff --git a/readme_img/ternary_relational_acc.png b/readme_img/ternary_relational_acc.png
new file mode 100644
index 0000000..272ea1a
Binary files /dev/null and b/readme_img/ternary_relational_acc.png differ
diff --git a/sort_of_clevr_generator.py b/sort_of_clevr_generator.py
index f3e3a12..b9b76db 100644
--- a/sort_of_clevr_generator.py
+++ b/sort_of_clevr_generator.py
@@ -2,21 +2,21 @@
 import os
 import numpy as np
 import random
-#import cPickle as pickle
+
 import pickle
 import warnings
 import argparse
+from translator import translate
 
 parser = argparse.ArgumentParser(description='Sort-of-CLEVR dataset generator')
-parser.add_argument('--seed', type=int, default=1, metavar='S',
-                    help='random seed (default: 1)')
-parser.add_argument('--t-subtype', type=int, default=-1,
-                    help='Force ternary questions to be of a given type')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--t-subtype', type=int, default=-1, help='Force ternary questions to be of a given type')
 
 args = parser.parse_args()
 
 random.seed(args.seed)
 np.random.seed(args.seed)
+
 train_size = 9800
 test_size = 200
 img_size = 75
@@ -24,7 +24,18 @@
 question_size = 18  ## 2 x (6 for one-hot vector of color), 3 for question type, 3 for question subtype
 q_type_idx = 12
 sub_q_type_idx = 15
-"""Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]"""
+
+'''
+Question [
+    obj 1 color: 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5,
+    obj 2 color: 'red':6, 'green':7, 'blue':8, 'orange':9, 'gray':10, 'yellow':11,
+    'no_rel':12, | 'rel':13, | 'ternary':14,
+    'sub-type[0]:15 query shape->rectangle/circle | closest-to->rectangle/circle | between->1~4',
+    'sub-type[1]:16 query horizontal position->yes/no | furthest-from->rectangle/circle | is-on-band->yes/no',
+    'sub-type[2]:17 query vertical position->yes/no | count->1~6 | count-obtuse-triangles->1~6'
+    ]
+Answer : [yes, no, rectangle, circle, r, g, b, o, k, y]
+'''
 nb_questions = 10
 dirs = './data'
@@ -109,6 +120,7 @@ def build_dataset():
                 answer = 0
             else:
                 answer = 1
+
         norel_answers.append(answer)
 
         """Binary Relational questions"""
@@ -125,7 +137,7 @@
             """closest-to->rectangle/circle"""
             my_obj = objects[color][1]
             dist_list = [((my_obj - obj[1]) ** 2).sum() for obj in objects]
-            dist_list[dist_list.index(0)] = 999
+            dist_list[dist_list.index(0)] = 99999  # self-distance sentinel; 999 could be smaller than a real squared pixel distance
             closest = dist_list.index(min(dist_list))
             if objects[closest][2] == 'r':
                 answer = 2
@@ -258,6 +270,11 @@
     norelations = (norel_questions, norel_answers)
 
     img = img/255.
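+    # each sample is (img, (ternary_qs, ternary_ans), (binary_qs, binary_ans),
+    # (norel_qs, norel_ans)); every question is the 18-dim vector documented
+    # in the header comment above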
+
+    # save all images (optional debug dump)
+    # img_count = num_total
+    # cv2.imwrite(os.path.join(dirs+'/all dataset','{}.png'.format(img_count)), cv2.resize(img*255, (img_size,img_size)))
+
     dataset = (img, ternary_relations, binary_relations, norelations)
     return dataset
@@ -270,7 +287,7 @@
     #img_count = 0
     #cv2.imwrite(os.path.join(dirs,'{}.png'.format(img_count)), cv2.resize(train_datasets[0][0]*255, (512,512)))
 
-
+# translate(train_datasets[0])
 print('saving datasets...')
 filename = os.path.join(dirs,'sort-of-clevr.pickle')
diff --git a/translator.py b/translator.py
index ac3d242..755e5de 100644
--- a/translator.py
+++ b/translator.py
@@ -1,36 +1,47 @@
 import cv2
+
+
 def translate(dataset):
+
-    img, (rel_questions, rel_answers), (norel_questions, norel_answers) = dataset
+    # build_dataset() now returns a 4-tuple that includes ternary questions
+    img, _ternary, (rel_questions, rel_answers), (norel_questions, norel_answers) = dataset
     colors = ['red ', 'green ', 'blue ', 'orange ', 'gray ', 'yellow ']
     answer_sheet = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6']
     questions = rel_questions + norel_questions
     answers = rel_answers + norel_answers
-
-    print rel_questions
-    print rel_answers
-
-
-    for question,answer in zip(questions,answers):
+    '''
+    Question [
+        obj 1 color: 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5,
+        obj 2 color: 'red':6, 'green':7, 'blue':8, 'orange':9, 'gray':10, 'yellow':11,
+        'no_rel':12, | 'rel':13, | 'ternary':14,
+        'sub-type[0]:15 query shape->rectangle/circle | closest-to->rectangle/circle | between->1~4',
+        'sub-type[1]:16 query horizontal position->yes/no | furthest-from->rectangle/circle | is-on-band->yes/no',
+        'sub-type[2]:17 query vertical position->yes/no | count->1~6 | count-obtuse-triangles->1~6'
+        ]
+    '''
+
+    ''' 3 for question type (index 12:14), 3 for question subtype (index 15:17) '''
+    for question, answer in zip(questions, answers):
+        query = ''
         query += colors[question.tolist()[0:6].index(1)]
-
-        if question[6] == 1:
-            if question[8] == 1:
+
+        if question[12] == 1:
+            if question[15] == 1:
                 query += 'shape?'
-            if question[9] == 1:
+            if question[16] == 1:
                 query += 'left?'
-            if question[10] == 1:
+            if question[17] == 1:
                 query += 'up?'
-        if question[7] == 1:
-            if question[8] == 1:
+
+        if question[13] == 1:
+            if question[15] == 1:
                 query += 'closest shape?'
-            if question[9] == 1:
+            if question[16] == 1:
                 query += 'furthest shape?'
-            if question[10] == 1:
+            if question[17] == 1:
                 query += 'count?' 
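+        # decode the integer label via answer_sheet:
+        # 0-1 yes/no, 2-3 rectangle/circle, 4-9 counts '1'..'6'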
        ans = answer_sheet[answer]
-        print query,'==>', ans
+        print(query, '==>', ans)
+
+    #cv2.imwrite('sample.jpg',(img*255).astype(np.int32))
     cv2.imshow('img',cv2.resize(img,(512,512)))
     cv2.waitKey(0)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..0605b28
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,2 @@
+from .project_ui import *
+from .preditor_util import *
\ No newline at end of file
diff --git a/utils/preditor_util.py b/utils/preditor_util.py
new file mode 100644
index 0000000..58a37d8
--- /dev/null
+++ b/utils/preditor_util.py
@@ -0,0 +1,115 @@
+from model import RN
+import cv2
+import string
+import numpy as np
+import torch
+import torch.nn.functional as F
+import argparse
+
+parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
+parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN', help='which model to use (default: RN)')
+parser.add_argument('--batch-size', type=int, default=1, metavar='N', help='input batch size for prediction (default: 1)')
+parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
+parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
+parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
+parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
+parser.add_argument('--resume', type=str, help='resume from model stored')
+parser.add_argument('--relation-type', type=str, default='binary', help='what kind of relations to learn. 
options: binary, ternary (default: binary)')
+
+args = parser.parse_args()
+args.cuda = False  # the demo predictor always runs on CPU
+
+class RNPredictor():
+    def __init__(self):
+        # load the trained weights once, so predict() does not reload them per call
+        self.model = RN(args)
+        checkpoint = torch.load('best_model/epoch_RN_20.pth', map_location='cpu')
+        self.model.load_state_dict(checkpoint)
+        self.model.eval()
+
+    def __tensor_data(self, data):
+        image, question = data
+        image = torch.FloatTensor(image).unsqueeze(0)
+        image = image.permute(0, 3, 1, 2)
+        question = torch.FloatTensor([question])
+        return image, question
+
+    def __map_question_type(self, idx):
+        # map the combo-box index 1..6 to (question type bit, subtype bit)
+        question_idx = 12
+        question_sub_type = 15
+        if idx == 1:
+            question_idx += 0
+            question_sub_type += 0
+        elif idx == 2:
+            question_sub_type += 1
+        elif idx == 3:
+            question_sub_type += 2
+        elif idx == 4:
+            question_idx += 1
+            question_sub_type += 0
+        elif idx == 5:
+            question_idx += 1
+            question_sub_type += 1
+        elif idx == 6:
+            question_idx += 1
+            question_sub_type += 2
+        return question_idx, question_sub_type
+
+    def __convert_to_vector(self, question_idx, tokens):
+        color_map = { 'red':0, 'green':1, 'blue':2, 'orange':3, 'gray':4, 'yellow':5 }
+        question = [ 0 for i in range(18) ]
+        for word in tokens:
+            if word in color_map.keys():
+                question[color_map[word]] = 1
+        question_idx, question_sub_type = self.__map_question_type(question_idx)
+        question[question_idx] = 1
+        question[question_sub_type] = 1
+        return question
+
+    def tokenize(self, sentence):
+        type_split = sentence.split('.')
+        question_idx = int(type_split[0])
+        sentence = type_split[1]
+        # strip all punctuation at once; replacing one character at a time on the
+        # original string only removed the last match and failed on punctuation-free input
+        no_punctuation_sentence = sentence.translate(str.maketrans('', '', string.punctuation))
+        tokens = no_punctuation_sentence.split()
+        question = self.__convert_to_vector(question_idx, tokens)
+        return question
+
+    def predict(self, data):
+        classes = [ 'yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6' ]
+        input_image, input_question = self.__tensor_data(data)
+        output = self.model(input_image, input_question)
+        # with batch size 1 the model squeezes the batch dimension away, so the
+        # class scores live on dim 0; log_softmax is monotonic and keeps the argmax
+        output = F.log_softmax(output, dim=0)
+        _, prediction = output.data.max(0)
+        return classes[prediction]
+
+'''
+# testing
+if __name__ == '__main__':
+    image = cv2.imread('image_RGB.jpg')
+    colors = [ 'red', 'green', 'blue', 'orange', 'gray', 'yellow' ]
+
+    questions = [
+        "1.What is the shape of the <color> object ?",
+        "2.Is the <color> object placed on the left side of the image ?",
+        "3.Is the <color> object placed on the upside of the image ?",
+        "4.What is the shape of the object closest to the <color> object ?",
+        "5.What is the shape of the object furthest from the <color> object ?",
+        "6.How many objects have the same shape as the <color> object ?" ]
+
+    rn_predictor = RNPredictor()
+    for question_text in questions:
+        for color in colors:
+            sentence = question_text.replace('<color>', color)
+            print('Question:', sentence)
+            question = rn_predictor.tokenize(sentence)
+            result = rn_predictor.predict((image/255, question))
+            print('Answer:', result)
+'''
\ No newline at end of file
diff --git a/utils/project_ui.py b/utils/project_ui.py
new file mode 100644
index 0000000..4934ce9
--- /dev/null
+++ b/utils/project_ui.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'project_UI.ui'
+#
+# Created by: PyQt5 UI code generator 5.9.2
+#
+# WARNING! All changes made in this file will be lost! 
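+
+# NOTE (added for clarity): the combo-box questions defined below are parsed by
+# utils.preditor_util.RNPredictor.tokenize(), which only reads the leading "<n>."
+# type index and the color word the user edits in; other words are ignored.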
+from PyQt5 import QtCore, QtWidgets, QtGui
+from PyQt5.QtGui import QFont as qfont
+from .shape_util import ShapeCanvas
+
+class Ui_MainWindow(object):
+    def setupUi(self, MainWindow):
+        MainWindow.setObjectName("MainWindow")
+        MainWindow.resize(740, 804)
+        self.setStyleSheet("background-color: rgb(224, 227, 210);")
+        self.centralwidget = QtWidgets.QWidget(MainWindow)
+        self.centralwidget.setObjectName("centralwidget")
+        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
+        self.pushButton.setGeometry(QtCore.QRect(630, 650, 81, 31))
+        self.pushButton.setStyleSheet("color: rgb(5,22,39)")
+        self.pushButton.setFont(qfont('Times', 10))
+        self.pushButton.setObjectName("pushButton")
+        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
+        self.pushButton_2.setGeometry(QtCore.QRect(630, 720, 81, 31))
+        self.pushButton_2.setStyleSheet("color: rgb(5,22,39)")
+        self.pushButton_2.setFont(qfont('Times', 10))
+        self.pushButton_2.setObjectName("pushButton_2")
+
+        self.label = ShapeCanvas(self)
+        self.label.setGeometry(QtCore.QRect(70, 20, 600, 600))
+        self.label.setAutoFillBackground(True)
+        self.label.setText("")
+        self.label.setObjectName("label")
+
+        self.label2 = QtWidgets.QLabel(self.centralwidget)
+        self.label2.setGeometry(QtCore.QRect(10, 630, 80, 15))
+        self.label2.setAutoFillBackground(True)
+        self.label2.setStyleSheet("color: rgb(5,22,39)")
+        self.label2.setFont(qfont('Times', 10, QtGui.QFont.Bold))
+        self.label2.setText("Question:")
+        self.label2.setObjectName("label2")
+
+        self.label3 = QtWidgets.QLabel(self.centralwidget)
+        self.label3.setGeometry(QtCore.QRect(10, 720, 70, 31))
+        self.label3.setAutoFillBackground(True)
+        self.label3.setStyleSheet("color: rgb(5,22,39)")
+        self.label3.setFont(qfont('Times', 10, QtGui.QFont.Bold))
+        self.label3.setText("Answer:")
+        self.label3.setObjectName("label3")
+
+        self.comboBox = QtWidgets.QComboBox(self.centralwidget)
+        self.comboBox.setEnabled(True)
+        self.comboBox.setGeometry(QtCore.QRect(30, 650, 571, 60))
+        self.comboBox.setStyleSheet("color: rgb(5,22,39)")
+        self.comboBox.setFont(qfont('Times', 12))
+        self.comboBox.setEditable(True)
+        self.comboBox.lineEdit().setFont(qfont('Times', 12, QtGui.QFont.Bold))
+        self.comboBox.setObjectName("comboBox")
+        # edit <color> into a color word (red/green/blue/orange/gray/yellow) before pressing OK
+        self.comboBox.addItem("1.What is the shape of the <color> object ?")
+        self.comboBox.addItem("2.Is the <color> object placed on the left side of the image ?")
+        self.comboBox.addItem("3.Is the <color> object placed on the upside of the image ?")
+        self.comboBox.addItem("4.What is the shape of the object closest to the <color> object ?")
+        self.comboBox.addItem("5.What is the shape of the object furthest from the <color> object ?")
+        self.comboBox.addItem("6.How many objects have the same shape as the <color> object ?")
+
+        self.label4 = QtWidgets.QLabel(self.centralwidget)
+        self.label4.setStyleSheet("color: rgb(5,22,39)")
+        self.label4.setFont(qfont('Times', 12))
+        self.label4.setGeometry(QtCore.QRect(100, 720, 500, 31))
+        self.label4.setObjectName("label4")
+        self.label4.setText("")
+        MainWindow.setCentralWidget(self.centralwidget)
+        self.menubar = QtWidgets.QMenuBar(MainWindow)
+        self.menubar.setGeometry(QtCore.QRect(0, 0, 852, 26))
+        self.menubar.setObjectName("menubar")
+        MainWindow.setMenuBar(self.menubar)
+        self.statusbar = 
QtWidgets.QStatusBar(MainWindow)
+        self.statusbar.setObjectName("statusbar")
+        MainWindow.setStatusBar(self.statusbar)
+
+        self.retranslateUi(MainWindow)
+        QtCore.QMetaObject.connectSlotsByName(MainWindow)
+
+    def retranslateUi(self, MainWindow):
+        _translate = QtCore.QCoreApplication.translate
+        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
+        self.pushButton.setText(_translate("MainWindow", "OK"))
+        self.pushButton_2.setText(_translate("MainWindow", "Random"))
diff --git a/utils/shape_util.py b/utils/shape_util.py
new file mode 100644
index 0000000..a597525
--- /dev/null
+++ b/utils/shape_util.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Jan 13 18:03:34 2022
+
+@author: helen
+"""
+
+from PyQt5 import QtWidgets
+from PyQt5.QtCore import QRect, Qt
+from PyQt5.QtGui import QPainter, QPen, QColor, QPixmap
+import random
+
+class Shape():
+    def __init__(self):
+        self.object_shape = []
+        self.object_x = [ 17, 114, 211, 308, 405, 502 ]
+        self.object_y = [ 500, 500, 500, 500, 500, 500 ]
+        self.object_color = [ "#FF0000",  # r
+                              "#00FF00",  # g
+                              "#0000FF",  # b
+                              "#ff9d00",  # o
+                              "#808080",  # k (gray)
+                              "#FFFF00" ] # y
+
+    def random_shape(self):
+        random.shuffle(self.object_color)
+        self.object_shape.clear()
+        for i in range(6):
+            self.object_shape.append(random.choice(['rectangle', 'circle']))
+
+class ShapeCanvas(QtWidgets.QLabel):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+
+        self.canvas = QPixmap(600, 600)
+        self.canvas.fill(QColor("white"))
+        self.setPixmap(self.canvas)
+
+        self.shape = Shape()
+        self.shape.random_shape()
+        self.shape_size = 80
+        self.drag_idx = -1
+
+    def paintEvent(self, event):
+        super(ShapeCanvas, self).paintEvent(event)
+        painter = QPainter(self)
+        painter.setRenderHint(QPainter.Antialiasing)
+
+        for i in range(6):
+            color = QColor(self.shape.object_color[i])
+            pen = QPen(color, 1)
+            painter.setBrush(color)
+            painter.setPen(pen)
+            if self.shape.object_shape[i] == 'rectangle':
+                painter.drawRect(QRect(self.shape.object_x[i], self.shape.object_y[i], self.shape_size, self.shape_size))
+            else:
+                painter.drawEllipse(QRect(self.shape.object_x[i], self.shape.object_y[i], self.shape_size, self.shape_size))
+
+    def mousePressEvent(self, event):
+        if event.button() == Qt.LeftButton:
+            self.drag_idx = self.drag_index(event.pos())
+
+    def drag_index(self, position):
+        # return the index of the first shape whose hit box contains the click, else -1
+        for i in range(6):
+            if self.is_pointSelected(self.shape.object_x[i], self.shape.object_y[i], position):
+                return i
+        return -1
+
+    def is_pointSelected(self, point_x, point_y, position):
+        # generous hit box: one shape_size in every direction around the shape's
+        # top-left corner (point_x, point_y)
+        x_min = point_x - self.shape_size
+        x_max = point_x + self.shape_size
+        y_min = point_y - self.shape_size
+        y_max = point_y + self.shape_size
+
+        x, y = position.x(), position.y()
+        if x_min < x < x_max and y_min < y < y_max:
+            return True
+        return False
+
+    def mouseMoveEvent(self, event):
+        # dragging moves the selected shape's top-left corner to the cursor
+        if self.drag_idx != -1:
+            self.shape.object_x[self.drag_idx] = event.pos().x()
+            self.shape.object_y[self.drag_idx] = event.pos().y()
+            self.update()
+
+    def mouseReleaseEvent(self, event):
+        self.drag_idx = -1
+        self.update()
+
+    def pushButton2_clicked(self):  # "Random" button
+        self.shape.random_shape()
+        self.repaint()
\ No newline at end of file
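For reference, a minimal sketch of driving the predictor without the GUI, mirroring the commented-out test at the bottom of `utils/preditor_util.py` (the checkpoint name, `image_RGB.jpg` path, and `<color>` placeholder convention are taken from the diff above; treat this as illustrative, not part of the patch):

```python
# headless sanity check for the demo predictor (illustrative sketch)
import cv2
from utils import RNPredictor  # utils/__init__.py re-exports preditor_util.RNPredictor

predictor = RNPredictor()
img = cv2.imread('image_RGB.jpg')  # 75x75 BGR canvas snapshot saved by the app
question = predictor.tokenize("1.What is the shape of the red object ?")
print(predictor.predict((img / 255, question)))  # e.g. 'rectangle' or 'circle'
```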