@@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
1313 readonly axis : number ;
1414}
1515
16- const validateInputs = ( inputs : readonly TensorView [ ] ) : void => {
16+ const validateInputs = ( inputs : readonly TensorView [ ] , axis : number ) : void => {
1717 if ( ! inputs || inputs . length < 1 ) {
1818 throw new Error ( 'too few inputs' ) ;
1919 }
20-
21- const inputType = inputs [ 0 ] . dataType ;
22- const inputDimensionality = inputs [ 0 ] . dims . length ;
23-
24- for ( const input of inputs ) {
20+ const referenceIndex = 0 ;
21+ const referenceInput = inputs [ referenceIndex ] ;
22+ const inputType = referenceInput . dataType ;
23+ const inputRank = referenceInput . dims . length ;
24+ inputs . forEach ( ( input , i ) => {
25+ if ( i === referenceIndex ) {
26+ return ;
27+ }
2528 // make sure types of all inputs match
2629 if ( input . dataType !== inputType ) {
2730 throw new Error ( 'input tensors should be one type' ) ;
2831 }
29-
3032 // make sure the dimensionality of all inputs are the same
31- if ( input . dims . length !== inputDimensionality ) {
33+ if ( input . dims . length !== inputRank ) {
3234 throw new Error ( 'input tensors should have the same shape' ) ;
3335 }
34- }
36+ input . dims . forEach ( ( dim , i ) => {
37+ if ( i !== axis && dim !== referenceInput . dims [ i ] ) {
38+ throw new Error ( 'non concat dimensions must match' ) ;
39+ }
40+ } ) ;
41+ } ) ;
3542} ;
3643
3744const calculateInputIndexImpl = ( numberOfTensors : number , sizeInConcatAxisStr : string ) : string => `
@@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
6471 return codeLines . join ( '\n' ) ;
6572} ;
6673
67- const createConcatProgramInfo = ( inputs : readonly TensorView [ ] , axis : number ) : ProgramInfo => {
68- const inputShape = inputs [ 0 ] . dims . slice ( ) ;
69- if ( axis >= inputShape . length || axis < ( - 1 * inputShape . length ) ) {
70- throw new Error ( 'axis specified for concat doesn\'t match input dimensionality' ) ;
71- }
72- const adjustedAxis = ( axis < 0 ) ? inputShape . length + axis : axis ;
73- // ensure all of the non-concatenated axes match each other
74- // calculate the shape of the output tensor while we do that
75- const outputShape = inputShape . slice ( 0 ) ;
76- for ( let i = 1 ; i < inputs . length ; i ++ ) {
77- const dataNShape = inputs [ i ] . dims . slice ( ) ;
78- for ( let axisIndex = 0 ; axisIndex < inputShape . length ; axisIndex ++ ) {
79- // add to the placeholder for computing output shape
80- if ( axisIndex === adjustedAxis ) {
81- outputShape [ adjustedAxis ] += dataNShape [ axisIndex ] ;
74+ const createConcatProgramInfo =
75+ ( inputs : readonly TensorView [ ] , adjustedAxis : number , outputShape : number [ ] , dataType : DataType ) : ProgramInfo => {
76+ const outputSize = ShapeUtil . size ( outputShape ) ;
77+
78+ const sizeInConcatAxis = new Array < number > ( inputs . length ) ;
79+ const inputVars = new Array < IndicesHelper > ( inputs . length ) ;
80+
81+ let previousSum = 0 ;
82+ const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ ] ;
83+ const inputRanks = [ ] ;
84+ const programUniforms : ProgramUniform [ ] = [ { type : DataType . uint32 , data : outputSize } ] ;
85+ for ( let i = 0 ; i < inputs . length ; ++ i ) {
86+ previousSum += inputs [ i ] . dims [ adjustedAxis ] ;
87+ sizeInConcatAxis [ i ] = previousSum ;
88+ inputRanks . push ( inputs [ i ] . dims . length ) ;
89+ inputVars [ i ] = inputVariable ( `input${ i } ` , dataType , inputRanks [ i ] ) ;
90+ inputDependencies . push ( 'rank' ) ;
91+ programUniforms . push ( { type : DataType . uint32 , data : sizeInConcatAxis [ i ] } ) ;
8292 }
83- // ensure all non-cancatenated axes match each other
84- else if ( inputShape [ axisIndex ] !== dataNShape [ axisIndex ] ) {
85- throw new Error ( 'non concat dimensions must match' ) ;
93+ for ( let i = 0 ; i < inputs . length ; ++ i ) {
94+ programUniforms . push ( ...createTensorShapeVariables ( inputs [ i ] . dims ) ) ;
8695 }
87- }
88- }
89-
90- const outputSize = ShapeUtil . size ( outputShape ) ;
91-
92- const sizeInConcatAxis = new Array < number > ( inputs . length ) ;
93- const inputVars = new Array < IndicesHelper > ( inputs . length ) ;
94- const dataType = inputs [ 0 ] . dataType ;
95-
96- let previousSum = 0 ;
97- const inputDependencies : ProgramInputTensorInfoDependency [ ] = [ ] ;
98- const inputRanks = [ ] ;
99- const programUniforms : ProgramUniform [ ] = [ { type : DataType . uint32 , data : outputSize } ] ;
100- for ( let i = 0 ; i < inputs . length ; ++ i ) {
101- previousSum += inputs [ i ] . dims [ adjustedAxis ] ;
102- sizeInConcatAxis [ i ] = previousSum ;
103- inputRanks . push ( inputs [ i ] . dims . length ) ;
104- inputVars [ i ] = inputVariable ( `input${ i } ` , dataType , inputRanks [ i ] ) ;
105- inputDependencies . push ( 'rank' ) ;
106- programUniforms . push ( { type : DataType . uint32 , data : sizeInConcatAxis [ i ] } ) ;
107- }
108- for ( let i = 0 ; i < inputs . length ; ++ i ) {
109- programUniforms . push ( ...createTensorShapeVariables ( inputs [ i ] . dims ) ) ;
110- }
111- programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
96+ programUniforms . push ( ...createTensorShapeVariables ( outputShape ) ) ;
11297
113- const output = outputVariable ( 'output' , dataType , outputShape . length ) ;
114- const indicesAxis = output . indicesGet ( 'indices' , adjustedAxis ) ;
115- const sizeInConcatAxisStr =
116- Array . from ( Array ( sizeInConcatAxis . length ) . keys ( ) ) . map ( i => `uniforms.sizeInConcatAxis${ i } ` ) . join ( ',' ) ;
117- const getShaderSource = ( shaderHelper : ShaderHelper ) => `
98+ const output = outputVariable ( 'output' , dataType , outputShape . length ) ;
99+ const indicesAxis = output . indicesGet ( 'indices' , adjustedAxis ) ;
100+ const sizeInConcatAxisStr =
101+ Array . from ( Array ( sizeInConcatAxis . length ) . keys ( ) ) . map ( i => `uniforms.sizeInConcatAxis${ i } ` ) . join ( ',' ) ;
102+ const getShaderSource = ( shaderHelper : ShaderHelper ) => `
118103
119104 ${ ( ( ) => {
120- shaderHelper . registerUniform ( 'outputSize' , 'u32' ) ;
121- for ( let i = 0 ; i < inputs . length ; i ++ ) {
122- shaderHelper . registerUniform ( `sizeInConcatAxis${ i } ` , 'u32' ) ;
123- }
124- return shaderHelper . declareVariables ( ...inputVars , output ) ;
125- } ) ( ) }
105+ shaderHelper . registerUniform ( 'outputSize' , 'u32' ) ;
106+ for ( let i = 0 ; i < inputs . length ; i ++ ) {
107+ shaderHelper . registerUniform ( `sizeInConcatAxis${ i } ` , 'u32' ) ;
108+ }
109+ return shaderHelper . declareVariables ( ...inputVars , output ) ;
110+ } ) ( ) }
126111
127112 ${ calculateInputIndexImpl ( sizeInConcatAxis . length , sizeInConcatAxisStr ) }
128113
@@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
140125 ${ assignOutputData ( inputVars , output ) }
141126 }` ;
142127
143- return {
144- name : 'Concat' ,
145- shaderCache : { hint : `${ axis } ` , inputDependencies} ,
146- getRunData : ( ) => ( {
147- outputs : [ { dims : outputShape , dataType : inputs [ 0 ] . dataType } ] ,
148- dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
149- programUniforms,
150- } ) ,
151- getShaderSource,
152- } ;
153- } ;
128+ return {
129+ name : 'Concat' ,
130+ shaderCache : { hint : `${ adjustedAxis } ` , inputDependencies} ,
131+ getRunData : ( ) => ( {
132+ outputs : [ { dims : outputShape , dataType} ] ,
133+ dispatchGroup : { x : Math . ceil ( outputSize / 64 /* workgroup size */ ) } ,
134+ programUniforms,
135+ } ) ,
136+ getShaderSource,
137+ } ;
138+ } ;
154139
155140export const concat = ( context : ComputeContext , attributes : ConcatAttributes ) : void => {
156- validateInputs ( context . inputs ) ;
141+ const inputs = context . inputs ;
142+ const inputShape = inputs [ 0 ] . dims ;
143+ const adjustedAxis = ShapeUtil . normalizeAxis ( attributes . axis , inputShape . length ) ;
144+ validateInputs ( inputs , adjustedAxis ) ;
145+ const outputShape = inputShape . slice ( ) ;
146+ outputShape [ adjustedAxis ] =
147+ inputs . reduce ( ( sum , input ) => sum + ( input . dims . length > adjustedAxis ? input . dims [ adjustedAxis ] : 0 ) , 0 ) ;
157148 // 0 length tensors are valid for concat, remove them
158- const nonEmptyInputs = context . inputs . filter ( input => ShapeUtil . size ( input . dims ) > 0 ) ;
159- context . compute ( createConcatProgramInfo ( nonEmptyInputs , attributes . axis ) , { inputs : nonEmptyInputs } ) ;
149+ const nonEmptyInputs = inputs . filter ( input => ShapeUtil . size ( input . dims ) > 0 ) ;
150+ context . compute (
151+ createConcatProgramInfo ( nonEmptyInputs , adjustedAxis , outputShape , inputs [ 0 ] . dataType ) , { inputs : nonEmptyInputs } ) ;
160152} ;
161153
162154export const parseConcatAttributes = ( attributes : Record < string , unknown > ) : ConcatAttributes =>
0 commit comments