Skip to content

Commit ee0536d

Browse files
committed
[MINOR] Change DNNLSTM to use MatrixBlockReshape
1 parent 5ff6274 commit ee0536d

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNLSTM.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,13 @@ public static long lstmGeneric(DnnParameters params) {
315315

316316
//store caches
317317
ifog = ifo.append(g, true);
318-
MatrixBlock cache_out_t = LibMatrixReorg.reshape(out, new MatrixBlock(), 1, cache_out.clen, true);
318+
MatrixBlock cache_out_t = out.reshape( 1, cache_out.clen, true);
319319
cache_out.leftIndexingOperations(cache_out_t, t, t,0, cache_out.clen - 1, null, MatrixObject.UpdateType.INPLACE );
320320

321-
MatrixBlock cache_c_t = LibMatrixReorg.reshape(c, new MatrixBlock(), 1, cache_c.clen, true);
321+
MatrixBlock cache_c_t = c.reshape(1,cache_c.clen, true);
322322
cache_c.leftIndexingOperations(cache_c_t, t, t,0, cache_c.clen - 1, null, MatrixObject.UpdateType.INPLACE );
323323

324-
MatrixBlock cache_ifog_t = LibMatrixReorg.reshape(ifog, new MatrixBlock(), 1, cache_ifog.clen, true);
324+
MatrixBlock cache_ifog_t = ifog.reshape(1, cache_ifog.clen, true);
325325
cache_ifog.leftIndexingOperations(cache_ifog_t, t, t,0,cache_ifog.clen - 1, null, MatrixObject.UpdateType.INPLACE );
326326
}
327327
return params.output.recomputeNonZeros();
@@ -373,9 +373,9 @@ public static long lstmBackwardGeneric(DnnParameters params) {
373373
dout_prev = dout.slice(0, dout.rlen-1, t*M, (t+1)*M - 1).binaryOperations(plus, dout_prev);
374374

375375
//load and reuse cached results from forward pass for the current time step
376-
MatrixBlock c_t = LibMatrixReorg.reshape(cache_c.slice(t, t, 0, cache_c.clen - 1), new MatrixBlock(), params.N, M, true);
377-
MatrixBlock c_prev = t==0 ? c0 : LibMatrixReorg.reshape(cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1), new MatrixBlock(), params.N, M, true);
378-
MatrixBlock ifog = LibMatrixReorg.reshape(cache_ifog.slice(t, t,0, cache_ifog.clen - 1), new MatrixBlock(), params.N, 4*M, true);
376+
MatrixBlock c_t = cache_c.slice(t, t, 0, cache_c.clen - 1).reshape( params.N, M, true);
377+
MatrixBlock c_prev = t==0 ? c0 : cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1).reshape(params.N, M, true);
378+
MatrixBlock ifog = cache_ifog.slice(t, t,0, cache_ifog.clen - 1).reshape(params.N, 4*M, true);
379379
MatrixBlock i = ifog.slice(0, ifog.rlen - 1, 0, M -1);
380380
MatrixBlock f = ifog.slice(0, ifog.rlen - 1, M, 2*M -1);
381381
MatrixBlock o = ifog.slice(0, ifog.rlen - 1, 2*M, 3*M -1);
@@ -422,7 +422,7 @@ public static long lstmBackwardGeneric(DnnParameters params) {
422422

423423
//load the current input vector and in the cached previous hidden state
424424
MatrixBlock x_t = x.slice(0, x.rlen - 1, t*params.D , (t+1)*params.D - 1);
425-
MatrixBlock out_prev = t==0 ? out0 : LibMatrixReorg.reshape(cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1), new MatrixBlock(), params.N, M, true);
425+
MatrixBlock out_prev = t==0 ? out0 : cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1).reshape( params.N, M, true);
426426

427427
//merge mm for dx and dout_prev: input = cbind(X_t, out_prev) # shape (N, D+M)
428428
MatrixBlock in_t = x_t.append(out_prev, true).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);

0 commit comments

Comments
 (0)