Description
The original code of the `get_similarity_matrix` function is:
```python
def get_similarity_matrix(self, batch_x):
    sample = batch_x.squeeze(-1)  # [bsz, in_len]
    diff = sample.unsqueeze(1) - sample.unsqueeze(0)
    # Compute the Euclidean distance (squared)
    dist_squared = torch.sum(diff ** 2, dim=-1)  # [bsz, bsz]
    param = torch.max(dist_squared)
    euc_similarity = torch.exp(-5 * dist_squared / param)
    return euc_similarity.to(self.device)
```

However, it doesn't work with multiple time series (an input with more than one channel), because `squeeze(-1)` only removes the last dimension when its size is 1. So I referred to https://github.com/RakibulHaqueSajal/CCM_Module.git.
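For illustration, here is a standalone sketch that runs the quoted code on a single-channel and a three-channel batch, assuming the input layout is `[bsz, in_len, channels]`. The function names `original_similarity` and `channelwise_similarity` are hypothetical, `self` and the device move are dropped, and the second function is only one possible channel-wise reading that keeps the same `exp(-5 * d^2 / max d^2)` kernel (with a per-sample max and a clamp against division by zero, both my own assumptions), not the repository's implementation.

```python
import torch


def original_similarity(batch_x: torch.Tensor) -> torch.Tensor:
    # Standalone copy of the quoted method (no `self`, no device move).
    sample = batch_x.squeeze(-1)                      # [bsz, in_len] only if the last dim has size 1
    diff = sample.unsqueeze(1) - sample.unsqueeze(0)  # pairwise differences between batch items
    dist_squared = torch.sum(diff ** 2, dim=-1)       # summed over the LAST dimension
    param = torch.max(dist_squared)
    return torch.exp(-5 * dist_squared / param)


def channelwise_similarity(batch_x: torch.Tensor) -> torch.Tensor:
    # HYPOTHETICAL variant: same exp(-5 * d^2 / max d^2) kernel, but the squared
    # distance is taken between the [in_len] series of each pair of channels,
    # then the per-sample matrices are averaged over the batch.
    x = batch_x.transpose(1, 2)                  # [bsz, C, in_len]
    diff = x.unsqueeze(2) - x.unsqueeze(1)       # [bsz, C, C, in_len]
    dist_squared = torch.sum(diff ** 2, dim=-1)  # [bsz, C, C]
    # Assumption: normalise each sample by its own largest distance (clamped to avoid 0/0).
    param = dist_squared.amax(dim=(1, 2), keepdim=True).clamp_min(1e-12)
    return torch.exp(-5 * dist_squared / param).mean(dim=0)  # [C, C]


torch.manual_seed(0)
single = torch.randn(4, 16, 1)  # [bsz, in_len, 1]
multi = torch.randn(4, 16, 3)   # [bsz, in_len, 3]

print(original_similarity(single).shape)    # torch.Size([4, 4])
print(original_similarity(multi).shape)     # torch.Size([4, 4, 16])
print(channelwise_similarity(multi).shape)  # torch.Size([3, 3])
```

With one channel the quoted code returns a `[bsz, bsz]` matrix of similarities between batch items; with three channels it silently returns `[bsz, bsz, in_len]`. Whether the max-normalisation should be per sample (as in this sketch) or over the whole batch is an open design choice.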
However, the implementation of `get_similarity_matrix` there is still ambiguous.
Could you tell us exactly how this function should be implemented?
Is either of the following versions correct?
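```python
def get_similarity_matrix(self, batch_data, sigma=5.0):
    # batch_data: [bsz, seq_len, num_channels]
    x = batch_data.transpose(1, 2)  # [bsz, num_channels, seq_len]
    # Euclidean (not squared) distance between channel series
    cdist = torch.cdist(x, x, p=2.0)  # [bsz, num_channels, num_channels]
    # Similarity matrix per input example
    sim_matrix = torch.exp(-cdist / (2 * sigma ** 2))
    # Average over the batch -> [num_channels, num_channels]
    return sim_matrix.mean(dim=0).to(self.device, dtype=torch.float32)
```

OR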
```python
def get_similarity_matrix(self, batch_data, sigma=5.0):
    batch_len, seq_len, num_channels = batch_data.shape
    similarity_matrix = torch.zeros((num_channels, num_channels), device=batch_data.device)
    # Compute point-by-point differences along the sequence length
    time_diffs = batch_data[:, 1:, :] - batch_data[:, :-1, :]  # Shape: (batch_len, seq_len-1, channel)
    # Compute mean of these differences over batch and sequence length
    channel_representations = time_diffs.mean(dim=(0, 1))  # Shape: (channel,)
    # Compute pairwise similarity
    for i in range(num_channels):
        for j in range(num_channels):
            diff = torch.norm(channel_representations[i] - channel_representations[j]) ** 2
            similarity_matrix[i, j] = torch.exp(-diff / (2 * sigma ** 2))
    return similarity_matrix.to(self.device)
```
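To make it easier to see what the second version actually computes, the sketch below compares it with a broadcasting equivalent that drops the double loop. The names `version2_loop` and `version2_vectorized` are hypothetical, and `self` and the device move are removed so the snippet runs on its own.

```python
import torch


def version2_loop(batch_data: torch.Tensor, sigma: float = 5.0) -> torch.Tensor:
    # The second version from the question, minus `self` and the device move.
    num_channels = batch_data.shape[-1]
    similarity_matrix = torch.zeros((num_channels, num_channels), device=batch_data.device)
    time_diffs = batch_data[:, 1:, :] - batch_data[:, :-1, :]
    channel_representations = time_diffs.mean(dim=(0, 1))  # one scalar per channel
    for i in range(num_channels):
        for j in range(num_channels):
            diff = torch.norm(channel_representations[i] - channel_representations[j]) ** 2
            similarity_matrix[i, j] = torch.exp(-diff / (2 * sigma ** 2))
    return similarity_matrix


def version2_vectorized(batch_data: torch.Tensor, sigma: float = 5.0) -> torch.Tensor:
    # Same computation, using broadcasting instead of the double loop.
    time_diffs = batch_data[:, 1:, :] - batch_data[:, :-1, :]  # [bsz, seq_len-1, C]
    rep = time_diffs.mean(dim=(0, 1))                            # [C]
    diff_sq = (rep.unsqueeze(1) - rep.unsqueeze(0)) ** 2         # [C, C]
    return torch.exp(-diff_sq / (2 * sigma ** 2))


torch.manual_seed(0)
batch = torch.randn(8, 32, 5)  # [bsz, seq_len, channels]
loop = version2_loop(batch)
vec = version2_vectorized(batch)
print(torch.allclose(loop, vec))  # True
```

Because the mean of first differences telescopes to (last value − first value) / (seq_len − 1), averaged over the batch, the second version reduces each channel to a single number: two channels with the same endpoints get similarity 1 regardless of their shape in between. The first version, by contrast, compares the full series of each channel, using the unsquared Euclidean distance returned by `torch.cdist`.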