model/soft_dtw_cuda.py: 31 changes (13 additions & 18 deletions)
@@ -60,7 +60,7 @@ def compute_softdtw_cuda(D, gamma, warp, bandwidth, max_i, max_j, n_passes, R):
        cuda.syncthreads()

@cuda.jit
-def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j, n_passes, E, G):
+def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j, n_passes, E):
    k = cuda.blockIdx.x
    tid = cuda.threadIdx.x

@@ -84,7 +84,6 @@ def compute_softdtw_backward_cuda(D, R, inv_gamma, warp, bandwidth, max_i, max_j
                b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j] - warp) * inv_gamma)
                c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i, j]) * inv_gamma)
                E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
-                G[k, i, j] = E[k, i + 1, j]+E[k, i, j+1]+E[k, i+1, j+1]

        cuda.syncthreads()

@@ -142,18 +141,16 @@ def backward(ctx, grad_output):
        E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
        E[:, -1, -1] = 1

-        G = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
-        G[:, -1, -1] = 1

        compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_),
                                                            cuda.as_cuda_array(R),
                                                            1.0 / gamma.item(), warp.item(), bandwidth.item(), N, M, n_passes,
-                                                            cuda.as_cuda_array(E), cuda.as_cuda_array(G))
-        G = G[:, 1:N + 1, 1:M + 1] # dR_D
+                                                            cuda.as_cuda_array(E))
+        E = E[:, 1:N + 1, 1:M + 1] # dR_D

-        tmp_G = G.unsqueeze(-1).expand(-1, -1, -1, H)
-        tmp_G = tmp_G * torch.sign(raw_D)
-        dR_X = tmp_G.sum(dim=2)
+        tmp_E = E.unsqueeze(-1).expand(-1, -1, -1, H)
+        tmp_E = tmp_E * torch.sign(raw_D)
+        dR_X = tmp_E.sum(dim=2)

        return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None, None

@@ -194,10 +191,8 @@ def cpu_compute_softdtw_backward(D_, R, gamma, warp, bandwidth):
    M = D_.shape[2]
    D = np.zeros((B, N + 2, M + 2))
    E = np.zeros((B, N + 2, M + 2))
-    G = np.zeros((B, N + 2, M + 2))
    D[:, 1:N + 1, 1:M + 1] = D_
    E[:, -1, -1] = 1
-    G[:, -1, -1] = 1
    for k in range(B):
        for j in range(M, 0, -1):
            for i in range(N, 0, -1):
@@ -216,9 +211,8 @@ def cpu_compute_softdtw_backward(D_, R, gamma, warp, bandwidth):
                b = np.exp(b0)
                c = np.exp(c0)
                E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
-                G[k, i, j] = E[k, i + 1, j]+E[k, i, j+1]+E[k, i+1, j+1]

-    return G[:, 1:N + 1, 1:M + 1]
+    return E[:, 1:N + 1, 1:M + 1]


class CPUSoftDTW(Function):
@@ -249,10 +243,10 @@ def backward(ctx, grad_output):
        g_ = gamma.item()
        w_ = warp.item()
        b_ = bandwidth.item()
-        G = torch.Tensor(cpu_compute_softdtw_backward(D_, R_, g_, w_, b_)).to(dev).type(dtype)
-        tmp_G = G.unsqueeze(-1).expand(-1, -1, -1, H)
-        tmp_G = tmp_G * torch.sign(raw_D)
-        dR_X = tmp_G.sum(dim=2)
+        E = torch.Tensor(cpu_compute_softdtw_backward(D_, R_, g_, w_, b_)).to(dev).type(dtype)
+        tmp_E = E.unsqueeze(-1).expand(-1, -1, -1, H)
+        tmp_E = tmp_E * torch.sign(raw_D)
+        dR_X = tmp_E.sum(dim=2)

        return grad_output.view(-1, 1, 1).expand_as(dR_X) * dR_X, None, None, None, None, None

@@ -298,9 +292,10 @@ def _manhattan_dist_func(x, y):
        return torch.abs(x - y).sum(3), (x - y)

    def forward(self, X, Y):

+        n_hidden = X.size(-1)
        func_dtw = self._get_func_dtw(X, Y)

        D_xy, raw_D_xy = self.dist_func(X, Y)
+        D_xy = D_xy / n_hidden
        return func_dtw(X, raw_D_xy, D_xy, self.gamma, self.warp, self.bandwidth)
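
For readers tracing the backward passes above, here is a small self-contained sanity check of the chain-rule step they both use: E plays the role of dLoss/dD for the Manhattan cost, and the gradient with respect to X is E broadcast over the hidden axis, multiplied by sign(X - Y), then summed over the Y-length axis. This is only an illustrative sketch, not code from the PR: the shapes (batch, length, hidden), the unsqueeze-based broadcasting, and the stand-in tensor W (which takes the place of E; the 1 / n_hidden scaling is irrelevant here because W is arbitrary) are assumptions inferred from _manhattan_dist_func and the forward method.

import torch

torch.manual_seed(0)
B, N, M, H = 2, 5, 7, 3                                  # batch, X length, Y length, hidden size (assumed shapes)
X = torch.randn(B, N, H, dtype=torch.double, requires_grad=True)
Y = torch.randn(B, M, H, dtype=torch.double)
W = torch.randn(B, N, M, dtype=torch.double)             # stands in for E, i.e. dLoss/dD

raw_D = X.unsqueeze(2) - Y.unsqueeze(1)                  # (B, N, M, H), same role as raw_D above
D = raw_D.abs().sum(3)                                   # (B, N, M), Manhattan cost matrix
loss = (W * D).sum()                                     # weight each cost entry by W and reduce
loss.backward()                                          # autograd reference for dLoss/dX

tmp_E = W.unsqueeze(-1).expand(-1, -1, -1, H)            # broadcast dLoss/dD over the hidden axis
dR_X = (tmp_E * torch.sign(raw_D.detach())).sum(dim=2)   # sum over the M (Y-length) axis, as in backward()

print(torch.allclose(X.grad, dR_X))                      # True: matches autograd

The check passes because d|X[b, i, h] - Y[b, j, h]| / dX[b, i, h] = sign(X[b, i, h] - Y[b, j, h]), so summing W[b, i, j] * sign(raw_D) over j reproduces exactly what autograd computes for the weighted Manhattan cost.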