@@ -315,13 +315,13 @@ public static long lstmGeneric(DnnParameters params) {
315315
316316 //store caches
317317 ifog = ifo .append (g , true );
318- MatrixBlock cache_out_t = LibMatrixReorg .reshape (out , new MatrixBlock (), 1 , cache_out .clen , true );
318+ MatrixBlock cache_out_t = out .reshape ( 1 , cache_out .clen , true );
319319 cache_out .leftIndexingOperations (cache_out_t , t , t ,0 , cache_out .clen - 1 , null , MatrixObject .UpdateType .INPLACE );
320320
321- MatrixBlock cache_c_t = LibMatrixReorg .reshape (c , new MatrixBlock (), 1 , cache_c .clen , true );
321+ MatrixBlock cache_c_t = c .reshape (1 , cache_c .clen , true );
322322 cache_c .leftIndexingOperations (cache_c_t , t , t ,0 , cache_c .clen - 1 , null , MatrixObject .UpdateType .INPLACE );
323323
324- MatrixBlock cache_ifog_t = LibMatrixReorg .reshape (ifog , new MatrixBlock (), 1 , cache_ifog .clen , true );
324+ MatrixBlock cache_ifog_t = ifog .reshape (1 , cache_ifog .clen , true );
325325 cache_ifog .leftIndexingOperations (cache_ifog_t , t , t ,0 ,cache_ifog .clen - 1 , null , MatrixObject .UpdateType .INPLACE );
326326 }
327327 return params .output .recomputeNonZeros ();
@@ -373,9 +373,9 @@ public static long lstmBackwardGeneric(DnnParameters params) {
373373 dout_prev = dout .slice (0 , dout .rlen -1 , t *M , (t +1 )*M - 1 ).binaryOperations (plus , dout_prev );
374374
375375 //load and reuse cached results from forward pass for the current time step
376- MatrixBlock c_t = LibMatrixReorg . reshape ( cache_c .slice (t , t , 0 , cache_c .clen - 1 ), new MatrixBlock (), params .N , M , true );
377- MatrixBlock c_prev = t ==0 ? c0 : LibMatrixReorg . reshape ( cache_c .slice (t - 1 , t - 1 , 0 , cache_c .clen - 1 ), new MatrixBlock (), params .N , M , true );
378- MatrixBlock ifog = LibMatrixReorg . reshape ( cache_ifog .slice (t , t ,0 , cache_ifog .clen - 1 ), new MatrixBlock (), params .N , 4 *M , true );
376+ MatrixBlock c_t = cache_c .slice (t , t , 0 , cache_c .clen - 1 ). reshape ( params .N , M , true );
377+ MatrixBlock c_prev = t ==0 ? c0 : cache_c .slice (t - 1 , t - 1 , 0 , cache_c .clen - 1 ). reshape ( params .N , M , true );
378+ MatrixBlock ifog = cache_ifog .slice (t , t ,0 , cache_ifog .clen - 1 ). reshape ( params .N , 4 *M , true );
379379 MatrixBlock i = ifog .slice (0 , ifog .rlen - 1 , 0 , M -1 );
380380 MatrixBlock f = ifog .slice (0 , ifog .rlen - 1 , M , 2 *M -1 );
381381 MatrixBlock o = ifog .slice (0 , ifog .rlen - 1 , 2 *M , 3 *M -1 );
@@ -422,7 +422,7 @@ public static long lstmBackwardGeneric(DnnParameters params) {
422422
423423 //load the current input vector and in the cached previous hidden state
424424 MatrixBlock x_t = x .slice (0 , x .rlen - 1 , t *params .D , (t +1 )*params .D - 1 );
425- MatrixBlock out_prev = t ==0 ? out0 : LibMatrixReorg . reshape ( cache_out .slice (t - 1 , t - 1 , 0 , cache_out .clen - 1 ), new MatrixBlock (), params .N , M , true );
425+ MatrixBlock out_prev = t ==0 ? out0 : cache_out .slice (t - 1 , t - 1 , 0 , cache_out .clen - 1 ). reshape ( params .N , M , true );
426426
427427 //merge mm for dx and dout_prev: input = cbind(X_t, out_prev) # shape (N, D+M)
428428 MatrixBlock in_t = x_t .append (out_prev , true ).reorgOperations (transpose , new MatrixBlock (), 0 , 0 , 0 );
0 commit comments