Skip to content

Commit d6b945c

Browse files
committed
Feature: Dot products with two constant vectors (or one with a single unit value) are simplified.
e.g. "f = dot(vec2(1, 2), vec2(3,4))" => "f = 11.0" "f = dot(vec2(0, 1), v)" => "f = v.y"
1 parent 9c8140b commit d6b945c

File tree

5 files changed

+74
-2
lines changed

5 files changed

+74
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ E.g.
567567
* Change ```pow(3.0, 2.0)``` to ```9.0```
568568
* Change ```float a = 1.2 / 2.3 * 4.5;``` to ```float a = 2.3478;```
569569
* Change ```vec2 f = vec2(1.1, 2.2) + 3.3 * 4.4;``` to ```vec2 f = vec2(15.62, 16.72);```
570+
* Change ```float f = dot(v, vec3(0, 1, 0));``` to ```float f = v.y;```
570571

571572
---
572573
## Replace Functions Calls With Result

ShaderShrinker/Shrinker.Parser/Optimizations/PerformArithmeticExtension.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,45 @@ public static bool PerformArithmetic(this SyntaxNode rootNode)
195195
didChange = true;
196196
}
197197

198+
// dot(vecN(nums), vecN(nums)) => <the result>
199+
foreach (var dotNode in functionCalls
200+
.Where(o => o.Name == "dot" && o.Params.Children.Any(n => (n.Token as TypeToken)?.IsVector() == true))
201+
.ToList())
202+
{
203+
var dotParams = dotNode.Params.GetCsv().ToList();
204+
var vectors = dotParams.Select(o => o.First()).Where(o => (o.Token as TypeToken)?.IsVector() == true).ToList();
205+
var nums = vectors.Select(GetVectorNumericCsv).ToList();
206+
if (nums.All(o => o == null))
207+
continue; // Neither arg is a simple numeric vector.
208+
209+
if (nums.Count(o => o != null) == 2)
210+
{
211+
// Both args are vectors and simple numeric - We can calculate the result.
212+
var sum = nums[0].Select((t, i) => t * nums[1][i]).Sum();
213+
dotNode.ReplaceWith(new GenericSyntaxNode(FloatToken.From(sum, MaxDp)));
214+
didChange = true;
215+
continue;
216+
}
217+
218+
// Do any of the vectors have only a single '1' component?
219+
var vectorWithAOne = vectors.FirstOrDefault(o => IsSingleElementOne(GetVectorNumericCsv(o)));
220+
if (vectorWithAOne == null)
221+
continue; // Nope.
222+
223+
// Yes - The 'other' param just needs a .x/.y/.z suffix.
224+
var otherParam = dotParams.Single(o => o[0] != vectorWithAOne);
225+
226+
// But first check it a single variable...
227+
if (otherParam.Count != 1 || otherParam.Single() is not GenericSyntaxNode node)
228+
continue;
229+
230+
// Replace the dot().
231+
var oneIndex = GetVectorNumericCsv(vectorWithAOne).IndexOf(1.0);
232+
var newNode = new GenericSyntaxNode($"{node.Token.Content}.{"xyzw"[oneIndex]}");
233+
dotNode.ReplaceWith(newNode);
234+
didChange = true;
235+
}
236+
198237
// Constant math/trig functions => <the result>
199238
var mathOp = new List<Tuple<string, Func<double, double>>>
200239
{
@@ -430,6 +469,11 @@ private static List<double> GetVectorNumericCsv(SyntaxNode vectorNode)
430469
return nums;
431470
}
432471

472+
private static bool IsSingleElementOne(IList<double> values) =>
473+
values != null &&
474+
values.All(o => o is 0.0 or 1.0) &&
475+
Math.Abs(values.Count(o => Math.Abs(o - 1.0) < 0.00001) - 1.0) < 0.00001;
476+
433477
private static bool IsSafeToPerformMath(SyntaxNode lhsNumNode, SymbolOperatorToken symbol, SyntaxNode rhsNumNode)
434478
{
435479
if (lhsNumNode == null || symbol == null || rhsNumNode == null)

ShaderShrinker/Shrinker.WpfApp/OptionsDialog.xaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,7 @@
573573
* Change ```pow(3.0, 2.0)``` to ```9.0```
574574
* Change ```float a = 1.2 / 2.3 * 4.5;``` to ```float a = 2.3478;```
575575
* Change ```vec2 f = vec2(1.1, 2.2) + 3.3 * 4.4;``` to ```vec2 f = vec2(15.62, 16.72);```
576+
* Change ```float f = dot(v, vec3(0, 1, 0));``` to ```float f = v.y;```
576577
</MdXaml:MarkdownScrollViewer>
577578
</CheckBox.ToolTip>
578579
</CheckBox>

ShaderShrinker/UnitTests/ShrinkerTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,8 +2131,8 @@ public void CheckInliningFunctionsCalledWithConstNumericParams(
21312131
"mat2 main() { return mat2(.76031, .64956, -.64956, .76031); }",
21322132
"mat2 main() { return mat2(.85252, -.52269, .52269, .85252); }",
21332133
"int main() { return 12; }",
2134-
"float main() { return dot(vec2(1, 2), vec2(1, 2)); }",
2135-
"float main() { return dot(vec2(1, 2), vec2(1, 2)); }",
2134+
"float main() { return 5.; }",
2135+
"float main() { return 5.; }",
21362136
"vec3 f(vec3 p) { float c = cos(.1 * p.x), s = sin(.1 * p.x); return vec3(mat2(c, -s, s, c) * p.xy, p.z); } vec3 p; vec3 main() { return f(p) + vec3(1); }",
21372137
"float main() { return 0.; }",
21382138
"float main() { return vec3(0); }",

ShaderShrinker/UnitTests/VectorArithmeticTests.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,32 @@ public void CheckArithmeticWithFractFunction()
265265
Assert.That(rootNode.ToCode().ToSimple(), Is.EqualTo("float f = .9;"));
266266
}
267267

268+
[Test, Sequential]
269+
public void CheckArithmeticWithSingleIndexDotFunction(
270+
[Values("vec3 v; float f = dot(vec3(0, 0, 1), v);|vec3 v; float f = v.z;",
271+
"vec3 v; float f = dot(vec3(0, 1, 0), v);|vec3 v; float f = v.y;",
272+
"vec3 v; float f = dot(vec3(1, 0, 0), v);|vec3 v; float f = v.x;",
273+
"vec3 v; float f = dot(v, vec3(0, 1, 0));|vec3 v; float f = v.y;",
274+
"vec2 v; float f = dot(v, vec2(1, 0));|vec2 v; float f = v.x;",
275+
"float f = dot(vec2(1, 2), vec2(3, 4));|float f = 11.;",
276+
"float f = dot(vec2(2), vec2(3, 4));|float f = 14.;",
277+
"vec2 v; float f = dot(v + v, vec2(1, 0));|vec2 v; float f = dot(v + v, vec2(1, 0));")] string code)
278+
{
279+
var input = code.Split('|')[0];
280+
var expected = code.Split('|')[1];
281+
282+
var lexer = new Lexer();
283+
lexer.Load(input);
284+
285+
var options = CustomOptions.None();
286+
options.PerformArithmetic = true;
287+
var rootNode = new Parser(lexer)
288+
.Parse()
289+
.Simplify(options);
290+
291+
Assert.That(rootNode.ToCode().ToSimple(), Is.EqualTo(expected));
292+
}
293+
268294
[Test, Sequential]
269295
public void CheckArithmeticWithTrigFunctions(
270296
[Values("float f = sin(3.141);",

0 commit comments

Comments
 (0)