
How does the get_similarity_matrix function work with multiple time series? #4

@uncertainty-mix


The original code of the get_similarity_matrix function is:

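```python
import torch

def get_similarity_matrix(self, batch_x):
    # Assumes a single channel: batch_x is [bsz, in_len, 1].
    # squeeze(-1) only drops the last dim when its size is 1.
    sample = batch_x.squeeze(-1)  # [bsz, in_len]
    diff = sample.unsqueeze(1) - sample.unsqueeze(0)  # [bsz, bsz, in_len]
    # Squared Euclidean distance between every pair of samples in the batch
    dist_squared = torch.sum(diff ** 2, dim=-1)  # [bsz, bsz]
    # Normalize by the largest pairwise distance in the batch
    param = torch.max(dist_squared)
    euc_similarity = torch.exp(-5 * dist_squared / param)
    return euc_similarity.to(self.device)
```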
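To see why the original function breaks on multivariate input, here is a minimal standalone sketch that repeats its arithmetic without `self.device` (the helper name `original_similarity` and the toy shapes are illustrative assumptions). `squeeze(-1)` is a no-op when the last dimension is larger than 1, so with 7 channels the pairwise difference becomes `[bsz, bsz, in_len, 7]` and `sum(dim=-1)` adds over channels rather than time steps, leaving a `[bsz, bsz, in_len]` tensor instead of a matrix.

```python
import torch


def original_similarity(batch_x: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as the original method, without self.device.
    sample = batch_x.squeeze(-1)
    diff = sample.unsqueeze(1) - sample.unsqueeze(0)
    dist_squared = torch.sum(diff ** 2, dim=-1)
    param = torch.max(dist_squared)
    return torch.exp(-5 * dist_squared / param)


torch.manual_seed(0)
univariate = torch.randn(4, 96, 1)    # [bsz, in_len, 1]
multivariate = torch.randn(4, 96, 7)  # [bsz, in_len, num_channels]

print(original_similarity(univariate).shape)    # torch.Size([4, 4])
print(original_similarity(multivariate).shape)  # torch.Size([4, 4, 96])
```

Note that even in the univariate case the result compares samples within the batch, not channels.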

However, it doesn't work with multiple time series (multivariate input), so I referred to "https://github.com/RakibulHaqueSajal/CCM_Module.git".
Even so, the implementation of get_similarity_matrix there is ambiguous.
Could you tell us exactly how this function should be implemented?
Is either of the following versions correct?

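```python
import torch
def get_similarity_matrix(self, batch_data, sigma=5.0):
    # batch_data: [bsz, seq_len, num_channels]
    x = batch_data.transpose(1, 2)  # [bsz, num_channels, seq_len]
    # Pairwise (unsquared) Euclidean distance between channels, per example
    cdist = torch.cdist(x, x, p=2.0)  # [bsz, num_channels, num_channels]

    # Similarity matrix per input example
    sim_matrix = torch.exp(-cdist / (2 * sigma ** 2))
    # Average over the batch -> [num_channels, num_channels]
    return sim_matrix.mean(dim=0).to(self.device, dtype=torch.float32)
```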
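On the first version: `torch.cdist` returns the unsquared Euclidean distance, so `exp(-cdist / (2 * sigma ** 2))` is not the usual Gaussian (RBF) kernel, which uses the squared distance, and it also differs from the original function, which uses squared distances normalized by their maximum. If a Gaussian kernel between channels is what's intended (an assumption; the repository does not confirm it), a sketch could square the distances before applying the kernel. The name `channel_similarity` and the toy batch are hypothetical.

```python
import torch


def channel_similarity(batch_data: torch.Tensor, sigma: float = 5.0) -> torch.Tensor:
    """Gaussian (RBF) similarity between channels, averaged over the batch.

    Sketch only. Assumes batch_data has shape [bsz, seq_len, num_channels].
    """
    x = batch_data.transpose(1, 2)                # [bsz, num_channels, seq_len]
    dist_squared = torch.cdist(x, x, p=2.0) ** 2  # squared L2 distance per example
    sim = torch.exp(-dist_squared / (2 * sigma ** 2))
    return sim.mean(dim=0)                        # [num_channels, num_channels]


torch.manual_seed(0)
batch = torch.randn(8, 96, 7)  # hypothetical batch: 8 examples, 96 steps, 7 channels
sim = channel_similarity(batch)
print(sim.shape)  # torch.Size([7, 7])
```

Whether to square the distance, normalize by the maximum as the original does, or choose sigma relative to the data scale is a design choice only the authors can settle; with unnormalized inputs and a long seq_len, a fixed sigma=5.0 on squared distances can push most off-diagonal entries toward 0.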

OR

```python
import torch

def get_similarity_matrix(self, batch_data, sigma=5.0):
    batch_len, seq_len, num_channels = batch_data.shape
    similarity_matrix = torch.zeros((num_channels, num_channels), device=batch_data.device)

    # Compute point-by-point differences along the sequence length
    time_diffs = batch_data[:, 1:, :] - batch_data[:, :-1, :]  # Shape: (batch_len, seq_len-1, num_channels)

    # Compute the mean of these differences over batch and sequence length,
    # so each channel is summarized by a single scalar
    channel_representations = time_diffs.mean(dim=(0, 1))  # Shape: (num_channels,)

    # Compute pairwise similarity (Gaussian kernel on the squared scalar difference)
    for i in range(num_channels):
        for j in range(num_channels):
            diff = torch.norm(channel_representations[i] - channel_representations[j]) ** 2
            similarity_matrix[i, j] = torch.exp(-diff / (2 * sigma ** 2))

    return similarity_matrix.to(self.device)
```
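On the second version: each channel is reduced to a single scalar, the mean first difference. Because consecutive differences telescope, that scalar equals the batch average of `(x[:, -1, c] - x[:, 0, c]) / (seq_len - 1)`, so the similarity only ever sees the first and last time steps. The sketch below (the name `version2_vectorized` and the toy batch are assumptions) reproduces the double loop without Python loops and checks the telescoping identity numerically.

```python
import torch


def version2_vectorized(batch_data: torch.Tensor, sigma: float = 5.0) -> torch.Tensor:
    # Same computation as the double loop in the second version, vectorized.
    time_diffs = batch_data[:, 1:, :] - batch_data[:, :-1, :]  # [bsz, seq_len-1, C]
    rep = time_diffs.mean(dim=(0, 1))                          # [C], one scalar per channel
    diff_squared = (rep.unsqueeze(1) - rep.unsqueeze(0)) ** 2  # [C, C]
    return torch.exp(-diff_squared / (2 * sigma ** 2))


torch.manual_seed(0)
batch = torch.randn(8, 96, 7)  # hypothetical batch: 8 examples, 96 steps, 7 channels

# The mean of consecutive differences telescopes to the endpoints only.
rep = (batch[:, 1:, :] - batch[:, :-1, :]).mean(dim=(0, 1))
endpoints = (batch[:, -1, :] - batch[:, 0, :]).mean(dim=0) / (batch.shape[1] - 1)
print(torch.allclose(rep, endpoints, atol=1e-5))        # True
print(version2_vectorized(batch).min().item() > 0.999)  # True
```

In this toy batch the scalar representations are tiny relative to sigma=5.0, so every entry of the matrix ends up very close to 1, meaning the channels are barely distinguished.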
