Skip to content

Commit 8594786

Browse files
committed
Merge branch 'main' of github.com:iSiddharth20/Generative-AI-Based-Spatio-Temporal-Fusion
2 parents f2a9f11 + cbb3657 commit 8594786

File tree

3 files changed

+162
-73
lines changed

3 files changed

+162
-73
lines changed

Code/autoencoder_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class Grey2RGBAutoEncoder(nn.Module):
1616
def __init__(self):
    """Build the encoder and decoder stacks of the greyscale-to-RGB autoencoder."""
    super().__init__()
    # Channel progression of each half: 1 input channel in, 3 output channels out.
    encoder_channels = [1, 8, 16, 32]
    decoder_channels = [32, 16, 8, 3]
    self.encoder = self._make_layers(encoder_channels)
    self.decoder = self._make_layers(decoder_channels, decoder=True)
2222

2323
# Helper function to create the encoder or decoder layers.
2424
def _make_layers(self, channels, decoder=False):

Code/main.py

Lines changed: 138 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -10,87 +10,138 @@
1010
from losses import LossMSE, LossMEP, SSIMLoss
1111
from training import Trainer
1212

13-
1413
# Import Necessary Libraries
1514
import os
1615
import traceback
1716
import torch
17+
import torch.multiprocessing as mp
18+
import torch.distributed as dist
19+
import platform
1820

1921
# Define Working Directories
2022
grayscale_dir = '../Dataset/Greyscale'
2123
rgb_dir = '../Dataset/RGB'
2224

2325
# Define Universal Parameters
24-
image_height = 400
25-
image_width = 600
26+
image_height = 4000
27+
image_width = 6000
2628
batch_size = 2
2729

28-
29-
def main():
30+
def get_backend():
    """Choose the torch.distributed backend for the current machine.

    NCCL is only usable with CUDA devices and a PyTorch build that ships it,
    so it is selected only on Linux when both are present. Every other
    configuration falls back to gloo.

    Returns:
        str: "nccl" or "gloo".
    """
    nccl_usable = (
        platform.system() == "Linux"
        and torch.cuda.is_available()
        and dist.is_available()
        and dist.is_nccl_available()
    )
    return "nccl" if nccl_usable else "gloo"
36+
37+
def main_worker(rank, world_size):
38+
# Set environment variables
39+
os.environ['MASTER_ADDR'] = 'localhost'
40+
os.environ['MASTER_PORT'] = '12345'
41+
# Initialize the distributed environment.
42+
torch.manual_seed(0)
43+
torch.backends.cudnn.enabled = True
44+
torch.backends.cudnn.benchmark = True
45+
dist.init_process_group(backend=get_backend(), init_method="env://", world_size=world_size, rank=rank)
46+
main(rank) # Call the existing main function.
47+
48+
def main(rank):
3049
# Initialize Dataset Object (PyTorch Tensors)
3150
try:
3251
dataset = CustomDataset(grayscale_dir, rgb_dir, (image_height, image_width), batch_size)
33-
print('Importing Dataset Complete.')
52+
if rank == 0:
53+
print('Importing Dataset Complete.')
3454
except Exception as e:
35-
print(f"Importing Dataset In-Complete : \n{e}")
55+
if rank == 0:
56+
print(f"Importing Dataset In-Complete : \n{e}")
57+
if rank == 0:
58+
print('-'*20) # Makes Output Readable
3659
# Import Loss Functions
3760
try:
3861
loss_mse = LossMSE() # Mean Squared Error Loss
3962
loss_mep = LossMEP(alpha=0.4) # Maximum Entropy Loss
4063
loss_ssim = SSIMLoss() # Structural Similarity Index Measure Loss
41-
print('Importing Loss Functions Complete.')
64+
if rank == 0:
65+
print('Importing Loss Functions Complete.')
4266
except Exception as e:
43-
print(f"Importing Loss Functions In-Complete : \n{e}")
44-
print('-'*20) # Makes Output Readable
67+
if rank == 0:
68+
print(f"Importing Loss Functions In-Complete : \n{e}")
69+
if rank == 0:
70+
print('-'*20) # Makes Output Readable
4571

4672
# Initialize AutoEncoder Model and Import Dataloader (Training, Validation)
4773
data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.2)
48-
print('AutoEncoder Model Data Imported.')
74+
if rank == 0:
75+
print('AutoEncoder Model Data Imported.')
4976
model_autoencoder = Grey2RGBAutoEncoder()
50-
print('AutoEncoder Model Initialized.')
51-
print('-'*20) # Makes Output Readable
77+
if rank == 0:
78+
print('AutoEncoder Model Initialized.')
79+
print('-'*20) # Makes Output Readable
5280

5381
# Initialize LSTM Model and Import Dataloader (Training, Validation)
5482
data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.25, sequence_length=2)
55-
print('LSTM Model Data Imported.')
83+
if rank == 0:
84+
print('LSTM Model Data Imported.')
5685
model_lstm = ConvLSTM(input_dim=1, hidden_dims=[1,1,1], kernel_size=(3, 3), num_layers=3, alpha=0.5)
57-
print('LSTM Model Initialized.')
58-
print('-'*20) # Makes Output Readable
86+
if rank == 0:
87+
print('LSTM Model Initialized.')
88+
print('-'*20) # Makes Output Readable
5989

6090
'''
6191
Initialize Trainer Objects
6292
'''
6393
# Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM
6494
os.makedirs('../Models/Method1', exist_ok=True) # Creating Directory for Model Saving
6595
model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth'
66-
trainer_autoencoder_baseline = Trainer(model_autoencoder, loss_mse, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
67-
print('Method-1 AutoEncoder Trainer Initialized.')
96+
trainer_autoencoder_baseline = Trainer(model=model_autoencoder,
97+
loss_function=loss_mse,
98+
optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
99+
model_save_path=model_save_path_ae,
100+
rank=rank)
101+
if rank == 0:
102+
print('Method-1 AutoEncoder Trainer Initialized.')
68103
model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth'
69-
trainer_lstm_baseline = Trainer(model_lstm, loss_mse, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
70-
print('Method-1 LSTM Trainer Initialized.')
71-
print('-'*10) # Makes Output Readable
104+
trainer_lstm_baseline = Trainer(model=model_lstm,
105+
loss_function=loss_mse,
106+
optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
107+
model_save_path=model_save_path_lstm,
108+
rank=rank)
109+
if rank == 0:
110+
print('Method-1 LSTM Trainer Initialized.')
111+
print('-'*10) # Makes Output Readable
72112

73113
# Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM
74114
os.makedirs('../Models/Method2', exist_ok=True) # Creating Directory for Model Saving
75115
model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth'
76-
trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
77-
print('Method-2 AutoEncoder Trainer Initialized.')
78-
print('Method-2 LSTM == Method-1 LSTM')
79-
print('-'*10) # Makes Output Readable
116+
trainer_autoencoder_m2 = Trainer(model=model_autoencoder,
117+
loss_function=loss_mep,
118+
optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
119+
model_save_path=model_save_path_ae,
120+
rank=rank)
121+
if rank == 0:
122+
print('Method-2 AutoEncoder Trainer Initialized.')
123+
print('Method-2 LSTM == Method-1 LSTM')
124+
print('-'*10) # Makes Output Readable
80125

81126
# Method 3 : Mean Squared Error Loss for AutoEncoder and SSIM Loss for LSTM
82127
os.makedirs('../Models/Method3', exist_ok=True) # Creating Directory for Model Saving
83-
print('Method-3 AutoEncoder == Method-1 AutoEncoder')
128+
if rank == 0:
129+
print('Method-3 AutoEncoder == Method-1 AutoEncoder')
84130
model_save_path_lstm = '../Models/Method3/model_lstm_m3.pth'
85-
trainer_lstm_m3 = Trainer(model_lstm, loss_ssim, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
86-
print('Method-3 LSTM Trainer Initialized.')
87-
print('-'*10) # Makes Output Readable
131+
trainer_lstm_m3 = Trainer(model=model_lstm,
132+
loss_function=loss_ssim,
133+
optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
134+
model_save_path=model_save_path_lstm,
135+
rank=rank)
136+
if rank == 0:
137+
print('Method-3 LSTM Trainer Initialized.')
138+
print('-'*10) # Makes Output Readable
88139

89140
# Method 4 : Proposed Method : Composite Loss (MSE + MaxEnt) for AutoEncoder and SSIM Loss for LSTM
90-
print('Method-4 AutoEncoder == Method-2 AutoEncoder')
91-
print('Method-4 LSTM == Method-3 LSTM')
92-
93-
print('-'*20) # Makes Output Readable
141+
if rank == 0:
142+
print('Method-4 AutoEncoder == Method-2 AutoEncoder')
143+
print('Method-4 LSTM == Method-3 LSTM')
144+
print('-'*20) # Makes Output Readable
94145

95146

96147
'''
@@ -99,55 +150,84 @@ def main():
99150
# Method-1
100151
try:
101152
epochs = 1
102-
print('Method-1 AutoEncoder Training Start')
153+
if rank == 0:
154+
print('Method-1 AutoEncoder Training Start')
103155
model_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
104-
print('Method-1 AutoEncoder Training Complete.')
156+
if rank == 0:
157+
print('Method-1 AutoEncoder Training Complete.')
105158
except Exception as e:
106-
print(f"Method-1 AutoEncoder Training Error : \n{e}")
159+
if rank == 0:
160+
print(f"Method-1 AutoEncoder Training Error : \n{e}")
107161
traceback.print_exc()
108-
print('-'*10) # Makes Output Readable
162+
finally:
163+
if rank == 0:
164+
trainer_autoencoder_baseline.cleanup_ddp()
165+
if rank == 0:
166+
print('-'*10) # Makes Output Readable
109167
try:
110168
epochs = 1
111-
print('Method-1 LSTM Training Start')
169+
if rank == 0:
170+
print('Method-1 LSTM Training Start')
112171
model_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val)
113-
print('Method-1 LSTM Training Complete.')
172+
if rank == 0:
173+
print('Method-1 LSTM Training Complete.')
114174
except Exception as e:
115-
print(f"Method-1 LSTM Training Error : \n{e}")
175+
if rank == 0:
176+
print(f"Method-1 LSTM Training Error : \n{e}")
116177
traceback.print_exc()
117-
print('-'*20) # Makes Output Readable
178+
finally:
179+
if rank == 0:
180+
trainer_lstm_baseline.cleanup_ddp()
181+
if rank == 0:
182+
print('-'*20) # Makes Output Readable
118183

119184
# Method-2
120185
try:
121186
epochs = 1
122-
print('Method-2 AutoEncoder Training Start')
187+
if rank == 0:
188+
print('Method-2 AutoEncoder Training Start')
123189
model_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
124-
print('Method-2 AutoEncoder Training Complete.')
190+
if rank == 0:
191+
print('Method-2 AutoEncoder Training Complete.')
125192
except Exception as e:
126-
print(f"Method-2 AutoEncoder Training Error : \n{e}")
193+
if rank == 0:
194+
print(f"Method-2 AutoEncoder Training Error : \n{e}")
127195
traceback.print_exc()
128-
print('-'*10) # Makes Output Readable
129-
print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
130-
print('-'*20) # Makes Output Readable
196+
finally:
197+
trainer_autoencoder_m2.cleanup_ddp()
198+
if rank == 0:
199+
print('-'*10) # Makes Output Readable
200+
print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
201+
print('-'*20) # Makes Output Readable
131202

132203
# Method-3
133-
print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
134-
print('-'*10) # Makes Output Readable
204+
if rank == 0:
205+
print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
206+
print('-'*10) # Makes Output Readable
135207
try:
136208
epochs = 1
137-
print('Method-3 LSTM Training Start.')
209+
if rank == 0:
210+
print('Method-3 LSTM Training Start.')
138211
model_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val)
139-
print('Method-3 LSTM Training Complete.')
212+
if rank == 0:
213+
print('Method-3 LSTM Training Complete.')
140214
except Exception as e:
141-
print(f"Method-3 LSTM Training Error : \n{e}")
215+
if rank == 0:
216+
print(f"Method-3 LSTM Training Error : \n{e}")
142217
traceback.print_exc()
143-
print('-'*20) # Makes Output Readable
218+
finally:
219+
trainer_lstm_m3.cleanup_ddp()
220+
if rank == 0:
221+
print('-'*20) # Makes Output Readable
144222

145223
# Method-4
146-
print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
147-
print('-'*10) # Makes Output Readable
148-
print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
149-
print('-'*20) # Makes Output Readable
224+
if rank == 0:
225+
print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
226+
print('-'*10) # Makes Output Readable
227+
print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
228+
print('-'*20) # Makes Output Readable
150229

151230

152231
if __name__ == '__main__':
    # One worker per visible GPU. device_count() is 0 on a CPU-only host, so
    # fall back to a single worker rather than spawning with nprocs=0.
    world_size = max(torch.cuda.device_count(), 1)
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)

Code/training.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,37 @@
1010

1111
# Import Necessary Libraries
1212
import torch
13-
import torch.nn as nn
13+
from torch.nn.parallel import DistributedDataParallel as DDP
14+
import torch.distributed as dist
1415

1516
# Define Training Class
1617
class Trainer():
17-
def __init__(self, model, loss_function, optimizer=None, model_save_path=None):
18-
# Use All Available CUDA GPUs for Training (if Available)
19-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20-
if torch.cuda.device_count() > 1:
21-
model = nn.DataParallel(model)
18+
def __init__(self, model, loss_function, optimizer=None, model_save_path=None, rank=None):
19+
self.rank = rank # Rank of the current process
20+
self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
2221
self.model = model.to(self.device)
2322
# Define the loss function
2423
self.loss_function = loss_function
2524
# Define the optimizer
2625
self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.model.parameters(), lr=0.001)
26+
# Wrap model with DDP
27+
if torch.cuda.device_count() > 1 and rank is not None:
28+
self.model = DDP(self.model, device_ids=[rank], find_unused_parameters=True)
2729
# Define the path to save the model
28-
self.model_save_path = model_save_path
30+
self.model_save_path = model_save_path if rank == 0 else None # Only save on master process
31+
32+
def cleanup_ddp(self):
    """Destroy the default process group if one is currently initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
2935

3036
def save_model(self):
    """Save the model weights to ``self.model_save_path``.

    Only the master process (rank 0) or a non-distributed trainer (rank
    None) writes, and nothing is written when no save path is configured.
    """
    if self.rank not in (None, 0) or self.model_save_path is None:
        return
    # DDP keeps the real network under ``.module``; saving that one keeps the
    # checkpoint keys free of the "module." prefix, so the file loads into an
    # unwrapped model.
    network = self.model.module if isinstance(self.model, DDP) else self.model
    # Save the model
    torch.save(network.state_dict(), self.model_save_path)
3340

3441
def train_autoencoder(self, epochs, train_loader, val_loader):
3542
# Print Names of All Available GPUs (if any) to Train the Model
36-
if torch.cuda.device_count() > 0:
43+
if torch.cuda.device_count() > 0 and self.rank == 0:
3744
gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
3845
print("\tGPUs being used for Training : ",gpu_names)
3946
best_val_loss = float('inf')
@@ -54,7 +61,8 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
5461
val_loss = sum(self.loss_function(self.model(input.to(self.device)), target.to(self.device)).item() for input, target in val_loader) # Compute Total Validation Loss
5562
val_loss /= len(val_loader) # Compute Average Validation Loss
5663
# Print epochs and losses
57-
print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
64+
if self.rank == 0:
65+
print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
5866
# If the current validation loss is lower than the best validation loss, save the model
5967
if val_loss < best_val_loss:
6068
best_val_loss = val_loss # Update the best validation loss
@@ -64,7 +72,7 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
6472

6573
def train_lstm(self, epochs, train_loader, val_loader):
6674
# Print Names of All Available GPUs (if any) to Train the Model
67-
if torch.cuda.device_count() > 0:
75+
if torch.cuda.device_count() > 0 and self.rank == 0:
6876
gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
6977
print("\tGPUs being used for Training : ",gpu_names)
7078
best_val_loss = float('inf')
@@ -88,7 +96,8 @@ def train_lstm(self, epochs, train_loader, val_loader):
8896
val_loss += self.loss_function(output_sequence, target_sequence).item() # Accumulate loss
8997
val_loss /= len(val_loader) # Average validation loss
9098
# Print epochs and losses
91-
print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
99+
if self.rank == 0:
100+
print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
92101
# Model saving based on validation loss
93102
if val_loss < best_val_loss:
94103
best_val_loss = val_loss

0 commit comments

Comments
 (0)