1010from losses import LossMSE , LossMEP , SSIMLoss
1111from training import Trainer
1212
13-
1413# Import Necessary Libraries
1514import os
1615import traceback
1716import torch
17+ import torch .multiprocessing as mp
18+ import torch .distributed as dist
19+ import platform
1820
# Define Working Directories
grayscale_dir = '../Dataset/Greyscale'  # input frames (single-channel)
rgb_dir = '../Dataset/RGB'              # target frames (colour)

# Define Universal Parameters
# NOTE(review): 4000x6000 per frame with batch_size 2 is extremely large —
# confirm these dimensions are intentional and fit in device memory.
image_height = 4000
image_width = 6000
batch_size = 2
28-
def get_backend():
    """Choose the torch.distributed backend for this machine.

    NCCL is the fastest backend but is Linux-only and requires CUDA GPUs;
    selecting it on a GPU-less Linux host makes ``init_process_group`` fail.
    Everywhere else fall back to the CPU-capable Gloo backend.

    Returns:
        str: ``"nccl"`` on Linux with CUDA available, otherwise ``"gloo"``.
    """
    # BUG FIX: previously "nccl" was returned for any Linux system, even
    # without CUDA devices, which breaks process-group initialization.
    if platform.system() == "Linux" and torch.cuda.is_available():
        return "nccl"
    return "gloo"
def main_worker(rank, world_size):
    """Per-process entry point launched by ``mp.spawn``.

    Sets up the rendezvous environment, seeds RNG identically on every rank
    so all model replicas start from the same weights, joins the process
    group, and runs the training pipeline for this rank.

    Args:
        rank: Index of this process (0 .. world_size - 1).
        world_size: Total number of spawned processes.
    """
    # setdefault lets an external launcher override the rendezvous address;
    # the previous hard-coded assignment silently clobbered user settings.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12345')
    # Initialize the distributed environment.
    torch.manual_seed(0)  # same seed on every rank -> identical replicas
    torch.backends.cudnn.enabled = True
    # benchmark mode is safe here because input shapes are fixed per run
    torch.backends.cudnn.benchmark = True
    dist.init_process_group(backend=get_backend(), init_method="env://",
                            world_size=world_size, rank=rank)
    main(rank)  # run the full training pipeline on this rank
def _log(rank, *messages):
    """Print each message from the rank-0 process only.

    Under DDP every rank runs the same code; funnelling console output
    through this helper keeps it single-copy instead of repeating the
    ``if rank == 0: print(...)`` pattern at every call site.
    """
    if rank == 0:
        for message in messages:
            print(message)

def main(rank):
    """Build datasets, models and trainers, then run all training methods.

    Args:
        rank: Distributed rank of this process; rank 0 owns console output.
    """
    # Initialize Dataset Object (PyTorch Tensors)
    try:
        dataset = CustomDataset(grayscale_dir, rgb_dir, (image_height, image_width), batch_size)
        _log(rank, 'Importing Dataset Complete.')
    except Exception as e:
        # NOTE(review): if construction fails, `dataset` stays undefined and
        # the batch extraction below raises NameError — consider re-raising.
        _log(rank, f"Importing Dataset In-Complete : \n{e}")
    _log(rank, '-' * 20)  # Makes Output Readable

    # Import Loss Functions
    try:
        loss_mse = LossMSE()           # Mean Squared Error Loss
        loss_mep = LossMEP(alpha=0.4)  # Maximum Entropy Loss
        loss_ssim = SSIMLoss()         # Structural Similarity Index Measure Loss
        _log(rank, 'Importing Loss Functions Complete.')
    except Exception as e:
        _log(rank, f"Importing Loss Functions In-Complete : \n{e}")
    _log(rank, '-' * 20)  # Makes Output Readable

    # Initialize AutoEncoder Model and Import Dataloader (Training, Validation)
    data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.2)
    _log(rank, 'AutoEncoder Model Data Imported.')
    model_autoencoder = Grey2RGBAutoEncoder()
    _log(rank, 'AutoEncoder Model Initialized.', '-' * 20)

    # Initialize LSTM Model and Import Dataloader (Training, Validation)
    data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.25, sequence_length=2)
    _log(rank, 'LSTM Model Data Imported.')
    model_lstm = ConvLSTM(input_dim=1, hidden_dims=[1, 1, 1], kernel_size=(3, 3), num_layers=3, alpha=0.5)
    _log(rank, 'LSTM Model Initialized.', '-' * 20)

    '''
    Initialize Trainer Objects
    '''
    # Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM
    os.makedirs('../Models/Method1', exist_ok=True)  # Creating Directory for Model Saving
    model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth'
    trainer_autoencoder_baseline = Trainer(model=model_autoencoder,
                                           loss_function=loss_mse,
                                           optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
                                           model_save_path=model_save_path_ae,
                                           rank=rank)
    _log(rank, 'Method-1 AutoEncoder Trainer Initialized.')
    model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth'
    trainer_lstm_baseline = Trainer(model=model_lstm,
                                    loss_function=loss_mse,
                                    optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
                                    model_save_path=model_save_path_lstm,
                                    rank=rank)
    _log(rank, 'Method-1 LSTM Trainer Initialized.', '-' * 10)

    # Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM
    os.makedirs('../Models/Method2', exist_ok=True)  # Creating Directory for Model Saving
    model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth'
    trainer_autoencoder_m2 = Trainer(model=model_autoencoder,
                                     loss_function=loss_mep,
                                     optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
                                     model_save_path=model_save_path_ae,
                                     rank=rank)
    _log(rank, 'Method-2 AutoEncoder Trainer Initialized.',
         'Method-2 LSTM == Method-1 LSTM', '-' * 10)

    # Method 3 : Mean Squared Error Loss for AutoEncoder and SSIM Loss for LSTM
    os.makedirs('../Models/Method3', exist_ok=True)  # Creating Directory for Model Saving
    _log(rank, 'Method-3 AutoEncoder == Method-1 AutoEncoder')
    model_save_path_lstm = '../Models/Method3/model_lstm_m3.pth'
    trainer_lstm_m3 = Trainer(model=model_lstm,
                              loss_function=loss_ssim,
                              optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
                              model_save_path=model_save_path_lstm,
                              rank=rank)
    _log(rank, 'Method-3 LSTM Trainer Initialized.', '-' * 10)

    # Method 4 : Proposed Method : Composite Loss (MSE + MaxEnt) for AutoEncoder and SSIM Loss for LSTM
    _log(rank, 'Method-4 AutoEncoder == Method-2 AutoEncoder',
         'Method-4 LSTM == Method-3 LSTM', '-' * 20)

    '''
    Train Models
    '''
    # Method-1
    try:
        epochs = 1
        _log(rank, 'Method-1 AutoEncoder Training Start')
        model_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
        _log(rank, 'Method-1 AutoEncoder Training Complete.')
    except Exception as e:
        _log(rank, f"Method-1 AutoEncoder Training Error : \n{e}")
        traceback.print_exc()
    finally:
        # BUG FIX: cleanup must run on EVERY rank — tearing down the process
        # group only on rank 0 leaves the remaining ranks hanging. This also
        # makes Method-1 consistent with the Method-2/3 cleanup below.
        trainer_autoencoder_baseline.cleanup_ddp()
    _log(rank, '-' * 10)  # Makes Output Readable
    try:
        epochs = 1
        _log(rank, 'Method-1 LSTM Training Start')
        model_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val)
        _log(rank, 'Method-1 LSTM Training Complete.')
    except Exception as e:
        _log(rank, f"Method-1 LSTM Training Error : \n{e}")
        traceback.print_exc()
    finally:
        trainer_lstm_baseline.cleanup_ddp()  # every rank, see note above
    _log(rank, '-' * 20)  # Makes Output Readable

    # Method-2
    try:
        epochs = 1
        _log(rank, 'Method-2 AutoEncoder Training Start')
        model_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
        _log(rank, 'Method-2 AutoEncoder Training Complete.')
    except Exception as e:
        _log(rank, f"Method-2 AutoEncoder Training Error : \n{e}")
        traceback.print_exc()
    finally:
        trainer_autoencoder_m2.cleanup_ddp()
    _log(rank, '-' * 10,
         "Method-2 LSTM == Method-1 LSTM, No Need To Train Again.",
         '-' * 20)

    # Method-3
    _log(rank, "Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.",
         '-' * 10)
    try:
        epochs = 1
        _log(rank, 'Method-3 LSTM Training Start.')
        model_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val)
        _log(rank, 'Method-3 LSTM Training Complete.')
    except Exception as e:
        _log(rank, f"Method-3 LSTM Training Error : \n{e}")
        traceback.print_exc()
    finally:
        trainer_lstm_m3.cleanup_ddp()
    _log(rank, '-' * 20)  # Makes Output Readable

    # Method-4
    _log(rank, "Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.",
         '-' * 10,
         "Method-4 LSTM == Method-3 LSTM, No Need To Train Again.",
         '-' * 20)
if __name__ == '__main__':
    # Spawn one worker per GPU. BUG FIX: on CPU-only machines
    # torch.cuda.device_count() is 0 and mp.spawn(nprocs=0) fails, so fall
    # back to a single worker (get_backend() then selects Gloo).
    world_size = max(torch.cuda.device_count(), 1)
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
0 commit comments