@@ -34,24 +34,14 @@ Base.length(a::Tensorizer) = mapreduce(sum,*,a.blocks)
34
34
35
35
36
36
function start (a:: TrivialTensorizer{d} ) where {d}
37
- if d== 2
38
- return invoke (start, Tuple{Tensorizer2D}, a)
39
- else
40
- # ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
41
- block = SVector {d} (Ones {Int} (d))
42
- return (block, (0 , nothing , nothing )), (0 ,length (a))
43
- end
37
+ # ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
38
+ block = SVector {d} (Ones {Int} (d))
39
+ return (block, (0 , nothing , nothing )), (0 ,length (a))
44
40
end
45
41
46
42
function next (a:: TrivialTensorizer{d} , iterator_tuple) where {d}
47
-
48
- if d== 2
49
- return invoke (next, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
50
- end
51
-
52
43
(block, (j, iterator, iter_state)), (i,tot) = iterator_tuple
53
44
54
-
55
45
@inline function check_block_finished (j, iterator, block)
56
46
if iterator === nothing
57
47
return true
@@ -82,19 +72,22 @@ function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
82
72
end
83
73
84
74
85
- function done (a:: TrivialTensorizer{d} , iterator_tuple) where {d}
86
- if d== 2
87
- return invoke (done, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
88
- end
89
- (_, (i,tot)) = iterator_tuple
75
+ function done (a:: TrivialTensorizer , iterator_tuple)
76
+ i, tot = last (iterator_tuple)
90
77
return i ≥ tot
91
78
end
92
79
93
80
94
81
# (blockrow,blockcol), (subrow,subcol), (rowshift,colshift), (numblockrows,numblockcols), (itemssofar, length)
95
- start (a:: Tensorizer2D{AA, BB} ) where {AA,BB} = (1 ,1 ), (1 ,1 ), (0 ,0 ), (a. blocks[1 ][1 ],a. blocks[2 ][1 ]), (0 ,length (a))
82
+ start (a:: Tensorizer2D ) = _start (a:: Tensorizer2D )
83
+ start (a:: TrivialTensorizer{2} ) = _start (a:: Tensorizer2D )
84
+
85
+ _start (a) = (1 ,1 ), (1 ,1 ), (0 ,0 ), (a. blocks[1 ][1 ],a. blocks[2 ][1 ]), (0 ,length (a))
86
+
87
+ next (a:: Tensorizer2D , state) = _next (a, state)
88
+ next (a:: TrivialTensorizer{2} , state) = _next (a, state)
96
89
97
- function next (a :: Tensorizer2D{AA, BB} , ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) where {AA,BB}
90
+ function _next (a, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot)))
98
91
ret = k+ rsh,j+ csh
99
92
if k== n && j== m # end of block
100
93
if J == 1 || K == length (a. blocks[1 ]) # end of new block
@@ -118,8 +111,10 @@ function next(a::Tensorizer2D{AA, BB}, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))
118
111
ret, ((K,J), (k,j), (rsh,csh), (n,m), (i+ 1 ,tot))
119
112
end
120
113
114
+ done (a:: Tensorizer2D , state) = _done (a, state)
115
+ done (a:: TrivialTensorizer{2} , state) = _done (a, state)
121
116
122
- done (a :: Tensorizer2D , ((K,J), (k,j), (rsh,csh), (n,m) , (i,tot))) = i ≥ tot
117
+ _done (a , (_, _, _, _ , (i,tot))) = i ≥ tot
123
118
124
119
iterate (a:: Tensorizer ) = next (a, start (a))
125
120
function iterate (a:: Tensorizer , st)
580
575
function totensor (it:: Tensorizer ,M:: AbstractVector )
581
576
n= length (M)
582
577
B= block (it,n)
583
- ds = dimensions (it)
584
578
585
579
# ret=zeros(eltype(M),[sum(it.blocks[i][1:min(B.n[1],length(it.blocks[i]))]) for i=1:length(it.blocks)]...)
586
580
0 commit comments