Skip to content

Commit 4f26639

Browse files
feat: Lay groundwork for deep type detection from functions
This first pass starts with static values and gives us the mvp of type setting/detection from other functions.
1 parent 43bb07e commit 4f26639

File tree

10 files changed

+268
-37
lines changed

10 files changed

+268
-37
lines changed

src/backend/cpu/function-node.js

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ module.exports = class CPUFunctionNode extends BaseFunctionNode {
8888
* @returns {Array} the append retArr
8989
*/
9090
astFunctionDeclaration(ast, retArr) {
91-
if (this.addFunction) {
92-
this.addFunction(null, utils.getAstString(this.jsFunctionString, ast));
93-
}
91+
this.builder.addFunction(null, utils.getAstString(this.jsFunctionString, ast));
9492
return retArr;
9593
}
9694

@@ -595,7 +593,7 @@ module.exports = class CPUFunctionNode extends BaseFunctionNode {
595593
*/
596594
astExpressionStatement(esNode, retArr) {
597595
this.astGeneric(esNode.expression, retArr);
598-
retArr.push(';\n');
596+
retArr.push(';');
599597
return retArr;
600598
}
601599

src/backend/function-builder-base.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ module.exports = class FunctionBuilderBase {
4040
addFunction(functionName, jsFunction, options) {
4141
this.addFunctionNode(
4242
new this.Node(functionName, jsFunction, options)
43-
.setAddFunction(this.addFunction.bind(this))
43+
.setBuilder(this)
4444
);
4545
}
4646

@@ -153,7 +153,7 @@ module.exports = class FunctionBuilderBase {
153153
*/
154154
addKernel(fnString, options) {
155155
const kernelNode = new this.Node('kernel', fnString, options);
156-
kernelNode.setAddFunction(this.addFunction.bind(this));
156+
kernelNode.setBuilder(this);
157157
kernelNode.isRootKernel = true;
158158
this.addFunctionNode(kernelNode);
159159
return kernelNode;
@@ -174,7 +174,7 @@ module.exports = class FunctionBuilderBase {
174174
*/
175175
addSubKernel(jsFunction, options) {
176176
const kernelNode = new this.Node(null, jsFunction, options);
177-
kernelNode.setAddFunction(this.addFunction.bind(this));
177+
kernelNode.setBuilder(this);
178178
kernelNode.isSubKernel = true;
179179
this.addFunctionNode(kernelNode);
180180
return kernelNode;

src/backend/function-node-base.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ module.exports = class BaseFunctionNode {
2929
constructor(functionName, jsFunction, options) {
3030
this.calledFunctions = [];
3131
this.calledFunctionsArguments = {};
32-
this.addFunction = null;
32+
this.builder = null;
3333
this.isRootKernel = false;
3434
this.isSubKernel = false;
3535
this.parent = null;
@@ -164,8 +164,8 @@ module.exports = class BaseFunctionNode {
164164
return this.paramTypes[this.paramNames.indexOf(paramName)] === 'Input';
165165
}
166166

167-
setAddFunction(fn) {
168-
this.addFunction = fn;
167+
setBuilder(builder) {
168+
this.builder = builder;
169169
return this;
170170
}
171171

src/backend/kernel-base.js

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ module.exports = class KernelBase {
110110
}
111111
}
112112

113-
setAddFunction(cb) {
114-
this.addFunction = cb;
115-
return this;
116-
}
117-
118113
setFunctions(functions) {
119114
this.functions = functions;
120115
return this;

src/backend/web-gl/function-node.js

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
5656
* @returns {Array} the append retArr
5757
*/
5858
astFunctionDeclaration(ast, retArr) {
59-
if (this.addFunction) {
60-
this.addFunction(null, utils.getAstString(this.jsFunctionString, ast));
61-
}
59+
this.builder.addFunction(null, utils.getAstString(this.jsFunctionString, ast));
6260
return retArr;
6361
}
6462

@@ -144,6 +142,15 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
144142
case 'Array':
145143
retArr.push('sampler2D');
146144
break;
145+
case 'vec2':
146+
retArr.push('vec2');
147+
break;
148+
case 'vec3':
149+
retArr.push('vec3');
150+
break;
151+
case 'vec4':
152+
retArr.push('vec4');
153+
break;
147154
default:
148155
retArr.push('float');
149156
}
@@ -417,7 +424,7 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
417424
} else {
418425
this.astGeneric(forNode.body, retArr);
419426
}
420-
retArr.push('} else {\n');
427+
retArr.push('\n} else {\n');
421428
retArr.push('break;\n');
422429
retArr.push('}\n');
423430
retArr.push('}\n');
@@ -618,7 +625,7 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
618625
*/
619626
astExpressionStatement(esNode, retArr) {
620627
this.astGeneric(esNode.expression, retArr);
621-
retArr.push(';\n');
628+
retArr.push(';');
622629
return retArr;
623630
}
624631

@@ -642,17 +649,30 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
642649
}
643650
const retDeclaration = [];
644651
this.astGeneric(declaration, retDeclaration);
645-
if (retDeclaration[2] === 'getImage2D(' || retDeclaration[2] === 'getImage3D(') {
646-
if (i === 0) {
647-
retArr.push('vec4 ');
648-
}
649-
this.declarations[declaration.id.name] = 'vec4';
650-
} else {
651-
if (i === 0) {
652-
retArr.push('float ');
653-
}
654-
this.declarations[declaration.id.name] = 'float';
652+
let type = 'float';
653+
switch (retDeclaration[2]) {
654+
case 'getImage2D(':
655+
case 'getImage3D(':
656+
if (i === 0) {
657+
retArr.push('vec4 ');
658+
}
659+
type = 'vec4';
660+
break;
661+
default:
662+
if (i === 0) {
663+
if (declaration.init && declaration.init.name && this.declarations[declaration.init.name]) {
664+
type = this.declarations[declaration.init.name];
665+
retArr.push(type + ' ');
666+
} else if (declaration.init && declaration.init.type && declaration.init.type === 'ArrayExpression') {
667+
type = 'vec' + declaration.init.elements.length;
668+
retArr.push(type + ' ');
669+
} else {
670+
retArr.push('float ');
671+
}
672+
}
673+
break;
655674
}
675+
this.declarations[declaration.id.name] = type;
656676
retArr.push.apply(retArr, retDeclaration);
657677
}
658678
retArr.push(';');
@@ -891,7 +911,11 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
891911
// This normally refers to the global read only input vars
892912
let variableType = null;
893913
if (mNode.object.name) {
894-
variableType = this.getParamType(mNode.object.name);
914+
if (this.declarations[mNode.object.name]) {
915+
variableType = this.declarations[mNode.object.name];
916+
} else {
917+
variableType = this.getParamType(mNode.object.name);
918+
}
895919
} else if (
896920
mNode.object &&
897921
mNode.object.object &&
@@ -901,6 +925,8 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
901925
variableType = this.getConstantType(mNode.object.property.name);
902926
}
903927
switch (variableType) {
928+
case 'vec2':
929+
case 'vec3':
904930
case 'vec4':
905931
// Get from local vec4
906932
this.astGeneric(mNode.object, retArr);
@@ -1191,7 +1217,7 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase {
11911217
astArrayExpression(arrNode, retArr) {
11921218
const arrLen = arrNode.elements.length;
11931219

1194-
retArr.push('float[' + arrLen + '](');
1220+
retArr.push('vec' + arrLen + '(');
11951221
for (let i = 0; i < arrLen; ++i) {
11961222
if (i > 0) {
11971223
retArr.push(', ');

src/core/gpu.js

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,13 @@ class GPU extends GPUCore {
290290
* @memberOf GPU#
291291
*
292292
* @param {Function|String} fn - JS Function to do conversion
293-
* @param {String[]|Object} paramTypes - Parameter type array, assumes all parameters are 'float' if null
294-
* @param {String} returnType - The return type, assumes 'float' if null
293+
* @param {Object} options
295294
*
296295
* @returns {GPU} returns itself
297296
*
298297
*/
299-
addFunction(fn, paramTypes, returnType) {
300-
this._runner.functionBuilder.addFunction(null, fn, paramTypes, returnType);
298+
addFunction(fn, options) {
299+
this._runner.functionBuilder.addFunction(null, fn, options);
301300
return this;
302301
}
303302

test/all.html

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
<!-- features -->
2020
<script src="features/add-custom-function.js"></script>
2121
<script src="features/add-custom-native-function.js"></script>
22+
<script src="features/add-typed-functions.js"></script>
2223
<script src="features/combine-kernels.js"></script>
2324
<script src="features/constants-array.js"></script>
2425
<script src="features/constants-float.js"></script>
@@ -49,6 +50,7 @@
4950

5051
<!-- internal -->
5152
<script src="internal/context-inheritance.js"></script>
53+
<script src="internal/deep-type-detection.js"></script>
5254
<script src="internal/function-builder.js"></script>
5355
<script src="internal/function-node.js"></script>
5456
<script src="internal/kernel-base.js"></script>

test/features/add-typed-functions.js

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
(function() {
2+
function vec2Test(mode) {
3+
var gpu = new GPU({ mode: mode });
4+
function typedFunction() {
5+
return [1, 2];
6+
}
7+
gpu.addFunction(typedFunction, {
8+
returnType: 'vec2'
9+
});
10+
var kernel = gpu.createKernel(function() {
11+
var result = typedFunction();
12+
return result[0] + result[1];
13+
})
14+
.setOutput([1]);
15+
var result = kernel();
16+
QUnit.assert.equal(result[0], 3);
17+
}
18+
19+
QUnit.test( 'add typed functions - vec2 - (auto)', function() {
20+
vec2Test(null);
21+
});
22+
})();

0 commit comments

Comments
 (0)