Skip to content

Commit 1743b01

Browse files
authored
Merge pull request #748 from mairooni/fix/fp16_conversion
Support FP32 to FP16 conversion across all backends
2 parents cf77420 + 4dce7c3 commit 1743b01

File tree

10 files changed

+439
-3
lines changed

10 files changed

+439
-3
lines changed

tornado-assembly/src/bin/tornado-test

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ __TEST_THE_WORLD__ = [
125125
## SPIR-V, OpenCL and PTX foundation tests
126126
TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestIntegers"),
127127
TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestFloats"),
128+
TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestHalfFloats"),
128129
TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestDoubles"),
129130
TestEntry("uk.ac.manchester.tornado.unittests.foundation.MultipleRuns"),
130131
TestEntry("uk.ac.manchester.tornado.unittests.foundation.TestLinearAlgebra"),

tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/lir/OCLLIRStmt.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,37 @@ public void emitCode(OCLCompilationResultBuilder crb, OCLAssembler asm) {
216216

217217
}
218218

219+
@Opcode("CONVERT_FLOAT_TO_HALF")
220+
public static class ConvertFloatToHalfStmt extends AbstractInstruction {
221+
222+
public static final LIRInstructionClass<ConvertFloatToHalfStmt> TYPE = LIRInstructionClass.create(ConvertFloatToHalfStmt.class);
223+
224+
@Use
225+
protected Value floatValue;
226+
@Def
227+
protected Value halfValue;
228+
229+
public ConvertFloatToHalfStmt(Value floatValue, Value halfValue) {
230+
super(TYPE);
231+
this.floatValue = floatValue;
232+
this.halfValue = halfValue;
233+
}
234+
235+
@Override
236+
public void emitCode(OCLCompilationResultBuilder crb, OCLAssembler asm) {
237+
asm.indent();
238+
asm.emitValue(crb, halfValue);
239+
asm.space();
240+
asm.assign();
241+
asm.space();
242+
asm.emit("(half) ");
243+
asm.emitValue(crb, floatValue);
244+
asm.delimiter();
245+
asm.eol();
246+
}
247+
248+
}
249+
219250
@Opcode("VADD_HALF")
220251
public static class VectorAddHalfStmt extends AbstractInstruction {
221252

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) 2025, APT Group, Department of Computer Science,
3+
* School of Engineering, The University of Manchester. All rights reserved.
4+
* Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved.
5+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
6+
*
7+
* This code is free software; you can redistribute it and/or modify it
8+
* under the terms of the GNU General Public License version 2 only, as
9+
* published by the Free Software Foundation.
10+
*
11+
* This code is distributed in the hope that it will be useful, but WITHOUT
12+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14+
* version 2 for more details (a copy is included in the LICENSE file that
15+
* accompanied this code).
16+
*
17+
* You should have received a copy of the GNU General Public License version
18+
* 2 along with this work; if not, write to the Free Software Foundation,
19+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20+
*
21+
*/
22+
package uk.ac.manchester.tornado.drivers.opencl.graal.nodes;
23+
24+
import jdk.vm.ci.meta.JavaKind;
25+
import jdk.vm.ci.meta.Value;
26+
import org.graalvm.compiler.core.common.LIRKind;
27+
import org.graalvm.compiler.core.common.type.StampFactory;
28+
import org.graalvm.compiler.graph.NodeClass;
29+
import org.graalvm.compiler.lir.Variable;
30+
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
31+
import org.graalvm.compiler.nodeinfo.NodeInfo;
32+
import org.graalvm.compiler.nodes.ValueNode;
33+
import org.graalvm.compiler.nodes.spi.LIRLowerable;
34+
import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
35+
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLKind;
36+
import uk.ac.manchester.tornado.drivers.opencl.graal.lir.OCLLIRStmt;
37+
38+
@NodeInfo
39+
public class OCLConvertFloatToHalf extends ValueNode implements LIRLowerable {
40+
41+
public static final NodeClass<OCLConvertFloatToHalf> TYPE = NodeClass.create(OCLConvertFloatToHalf.class);
42+
43+
@Input
44+
private ValueNode floatValueNode;
45+
46+
public OCLConvertFloatToHalf(ValueNode floatValueNode) {
47+
super(TYPE, StampFactory.forKind(JavaKind.Short));
48+
this.floatValueNode = floatValueNode;
49+
}
50+
51+
public void generate(NodeLIRBuilderTool generator) {
52+
LIRGeneratorTool tool = generator.getLIRGeneratorTool();
53+
Variable halfValue = tool.newVariable(LIRKind.value(OCLKind.HALF));
54+
Value floatValue = generator.operand(floatValueNode);
55+
tool.append(new OCLLIRStmt.ConvertFloatToHalfStmt(floatValue, halfValue));
56+
generator.setResult(this, halfValue);
57+
}
58+
59+
}

tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/graal/phases/TornadoHalfFloatReplacement.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.DivHalfNode;
5454
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.HalfFloatConstantNode;
5555
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.MultHalfNode;
56+
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLConvertFloatToHalf;
5657
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.OCLConvertHalfToFloat;
5758
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.ReadHalfFloatNode;
5859
import uk.ac.manchester.tornado.drivers.opencl.graal.nodes.SubHalfNode;
@@ -99,7 +100,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
99100

100101
// replace reads with halfFloat reads
101102
for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
102-
if (javaRead.successors().first() instanceof NewInstanceNode) {
103+
if (javaRead.successors().first() instanceof NewInstanceNode && javaRead.getReadKind() == JavaKind.Short) {
103104
NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
104105
if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
105106
if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
@@ -124,6 +125,14 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
124125
newInstanceNode.replaceAtUsages(valueInput);
125126
deleteFixed(newInstanceNode);
126127
deleteFixed(newHalfFloatInstance);
128+
} else if (newInstanceNode.successors().first() instanceof JavaReadNode readValue && readValue.getReadKind() == JavaKind.Float) {
129+
OCLConvertFloatToHalf convertFloatToHalf = new OCLConvertFloatToHalf(readValue);
130+
graph.addWithoutUnique(convertFloatToHalf);
131+
newInstanceNode.replaceAtUsages(convertFloatToHalf);
132+
for (NewHalfFloatInstance newHalfFloatInstance : readValue.usages().filter(NewHalfFloatInstance.class)) {
133+
deleteFixed(newHalfFloatInstance);
134+
}
135+
deleteFixed(newInstanceNode);
127136
}
128137
}
129138
}
@@ -264,6 +273,10 @@ private static ValueNode getHalfFloatValue(ValueNode halfFloatValue, StructuredG
264273
HalfFloatConstantNode halfFloatConstantNode = new HalfFloatConstantNode(floatValue);
265274
graph.addWithoutUnique(halfFloatConstantNode);
266275
return halfFloatConstantNode;
276+
} else if (halfFloatValue instanceof JavaReadNode javaReadNode && javaReadNode.getReadKind() == JavaKind.Float) {
277+
OCLConvertFloatToHalf convertFloatToHalf = new OCLConvertFloatToHalf(javaReadNode);
278+
graph.addWithoutUnique(convertFloatToHalf);
279+
return convertFloatToHalf;
267280
} else {
268281
return halfFloatValue;
269282
}

tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/lir/PTXLIRStmt.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,35 @@ public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) {
866866

867867
}
868868

869+
public static class ConvertFloatToHalfStmt extends AbstractInstruction {
870+
871+
public static final LIRInstructionClass<ConvertFloatToHalfStmt> TYPE = LIRInstructionClass.create(ConvertFloatToHalfStmt.class);
872+
873+
@Use
874+
protected Value floatValue;
875+
@Def
876+
protected Value halfValue;
877+
878+
public ConvertFloatToHalfStmt(Value floatValue, Value halfValue) {
879+
super(TYPE);
880+
this.floatValue = floatValue;
881+
this.halfValue = halfValue;
882+
}
883+
884+
@Override
885+
public void emitCode(PTXCompilationResultBuilder crb, PTXAssembler asm) {
886+
asm.emitSymbol(TAB);
887+
asm.emit(CONVERT + DOT + "rn" + DOT + "f16" + DOT + "f32");
888+
asm.emitSymbol(SPACE);
889+
asm.emitValue(halfValue);
890+
asm.emitSymbol(COMMA + SPACE);
891+
asm.emitValue(floatValue);
892+
asm.delimiter();
893+
asm.eol();
894+
}
895+
896+
}
897+
869898
@Opcode("LOCAL_MEMORY_ACCESS")
870899
public static class LocalMemoryAccessStmt extends AbstractInstruction {
871900

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) 2025, APT Group, Department of Computer Science,
3+
* School of Engineering, The University of Manchester. All rights reserved.
4+
* Copyright (c) 2009, 2017, Oracle and/or its affiliates. All rights reserved.
5+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
6+
*
7+
* This code is free software; you can redistribute it and/or modify it
8+
* under the terms of the GNU General Public License version 2 only, as
9+
* published by the Free Software Foundation.
10+
*
11+
* This code is distributed in the hope that it will be useful, but WITHOUT
12+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14+
* version 2 for more details (a copy is included in the LICENSE file that
15+
* accompanied this code).
16+
*
17+
* You should have received a copy of the GNU General Public License version
18+
* 2 along with this work; if not, write to the Free Software Foundation,
19+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20+
*
21+
*/
22+
package uk.ac.manchester.tornado.drivers.ptx.graal.nodes;
23+
24+
import jdk.vm.ci.meta.JavaKind;
25+
import jdk.vm.ci.meta.Value;
26+
import org.graalvm.compiler.core.common.LIRKind;
27+
import org.graalvm.compiler.core.common.type.StampFactory;
28+
import org.graalvm.compiler.graph.NodeClass;
29+
import org.graalvm.compiler.lir.Variable;
30+
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
31+
import org.graalvm.compiler.nodeinfo.NodeInfo;
32+
import org.graalvm.compiler.nodes.ValueNode;
33+
import org.graalvm.compiler.nodes.spi.LIRLowerable;
34+
import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
35+
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXKind;
36+
import uk.ac.manchester.tornado.drivers.ptx.graal.lir.PTXLIRStmt;
37+
38+
@NodeInfo
39+
public class PTXConvertFloatToHalf extends ValueNode implements LIRLowerable {
40+
41+
public static final NodeClass<PTXConvertFloatToHalf> TYPE = NodeClass.create(PTXConvertFloatToHalf.class);
42+
43+
@Input
44+
private ValueNode floatValueNode;
45+
46+
public PTXConvertFloatToHalf(ValueNode floatValueNode) {
47+
super(TYPE, StampFactory.forKind(JavaKind.Short));
48+
this.floatValueNode = floatValueNode;
49+
}
50+
51+
public void generate(NodeLIRBuilderTool generator) {
52+
LIRGeneratorTool tool = generator.getLIRGeneratorTool();
53+
Variable halfValue = tool.newVariable(LIRKind.value(PTXKind.F16));
54+
Value floatValue = generator.operand(floatValueNode);
55+
tool.append(new PTXLIRStmt.ConvertFloatToHalfStmt(floatValue, halfValue));
56+
generator.setResult(this, halfValue);
57+
}
58+
59+
}

tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/graal/phases/TornadoHalfFloatReplacement.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.AddHalfNode;
5454
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.HalfFloatConstantNode;
5555
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.MultHalfNode;
56+
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXConvertFloatToHalf;
5657
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXConvertHalfToFloat;
5758
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.PTXHalfFloatDivisionNode;
5859
import uk.ac.manchester.tornado.drivers.ptx.graal.nodes.ReadHalfFloatNode;
@@ -99,7 +100,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
99100

100101
// replace reads with halfFloat reads
101102
for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
102-
if (javaRead.successors().first() instanceof NewInstanceNode) {
103+
if (javaRead.successors().first() instanceof NewInstanceNode && javaRead.getReadKind() == JavaKind.Short) {
103104
NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
104105
if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
105106
if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
@@ -124,6 +125,14 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
124125
newInstanceNode.replaceAtUsages(valueInput);
125126
deleteFixed(newInstanceNode);
126127
deleteFixed(newHalfFloatInstance);
128+
} else if (newInstanceNode.successors().first() instanceof JavaReadNode readValue && readValue.getReadKind() == JavaKind.Float) {
129+
PTXConvertFloatToHalf convertFloatToHalf = new PTXConvertFloatToHalf(readValue);
130+
graph.addWithoutUnique(convertFloatToHalf);
131+
newInstanceNode.replaceAtUsages(convertFloatToHalf);
132+
for (NewHalfFloatInstance newHalfFloatInstance : readValue.usages().filter(NewHalfFloatInstance.class)) {
133+
deleteFixed(newHalfFloatInstance);
134+
}
135+
deleteFixed(newInstanceNode);
127136
}
128137
}
129138
}
@@ -262,6 +271,10 @@ private static ValueNode getHalfFloatValue(ValueNode halfFloatValue, StructuredG
262271
HalfFloatConstantNode halfFloatConstantNode = new HalfFloatConstantNode(floatValue);
263272
graph.addWithoutUnique(halfFloatConstantNode);
264273
return halfFloatConstantNode;
274+
} else if (halfFloatValue instanceof JavaReadNode javaReadNode && javaReadNode.getReadKind() == JavaKind.Float) {
275+
PTXConvertFloatToHalf convertFloatToHalf = new PTXConvertFloatToHalf(javaReadNode);
276+
graph.addWithoutUnique(convertFloatToHalf);
277+
return convertFloatToHalf;
265278
} else {
266279
return halfFloatValue;
267280
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* This file is part of Tornado: A heterogeneous programming framework:
3+
* https://github.com/beehive-lab/tornadovm
4+
*
5+
* Copyright (c) 2025, APT Group, Department of Computer Science,
6+
* School of Engineering, The University of Manchester. All rights reserved.
7+
* Copyright (c) 2009-2021, Oracle and/or its affiliates. All rights reserved.
8+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
9+
*
10+
* This code is free software; you can redistribute it and/or modify it
11+
* under the terms of the GNU General Public License version 2 only, as
12+
* published by the Free Software Foundation.
13+
*
14+
* This code is distributed in the hope that it will be useful, but WITHOUT
15+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
16+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
17+
* version 2 for more details (a copy is included in the LICENSE file that
18+
* accompanied this code).
19+
*
20+
* You should have received a copy of the GNU General Public License version
21+
* 2 along with this work; if not, write to the Free Software Foundation,
22+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
23+
*
24+
*/
25+
package uk.ac.manchester.tornado.drivers.spirv.graal.nodes;
26+
27+
import jdk.vm.ci.meta.JavaKind;
28+
import jdk.vm.ci.meta.Value;
29+
import org.graalvm.compiler.core.common.LIRKind;
30+
import org.graalvm.compiler.core.common.type.StampFactory;
31+
import org.graalvm.compiler.graph.NodeClass;
32+
import org.graalvm.compiler.lir.Variable;
33+
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;
34+
import org.graalvm.compiler.nodeinfo.NodeInfo;
35+
import org.graalvm.compiler.nodes.ValueNode;
36+
import org.graalvm.compiler.nodes.spi.LIRLowerable;
37+
import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
38+
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVKind;
39+
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVLIRStmt;
40+
import uk.ac.manchester.tornado.drivers.spirv.graal.lir.SPIRVUnary;
41+
42+
@NodeInfo
43+
public class SPIRVConvertFloatToHalf extends ValueNode implements LIRLowerable {
44+
45+
public static final NodeClass<SPIRVConvertFloatToHalf> TYPE = NodeClass.create(SPIRVConvertFloatToHalf.class);
46+
47+
@Input
48+
private ValueNode floatValueNode;
49+
50+
public SPIRVConvertFloatToHalf(ValueNode floatValueNode) {
51+
super(TYPE, StampFactory.forKind(JavaKind.Short));
52+
this.floatValueNode = floatValueNode;
53+
}
54+
55+
public void generate(NodeLIRBuilderTool generator) {
56+
LIRGeneratorTool tool = generator.getLIRGeneratorTool();
57+
Variable halfValue = tool.newVariable(LIRKind.value(SPIRVKind.OP_TYPE_FLOAT_16));
58+
Value floatValue = generator.operand(floatValueNode);
59+
LIRKind lirKind = LIRKind.value(SPIRVKind.OP_TYPE_FLOAT_16);
60+
SPIRVUnary.CastOperations cast = new SPIRVUnary.CastFloatDouble(lirKind, halfValue, floatValue, SPIRVKind.OP_TYPE_FLOAT_16);
61+
tool.append(new SPIRVLIRStmt.AssignStmt(halfValue, cast));
62+
generator.setResult(this, halfValue);
63+
}
64+
}

tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/graal/phases/TornadoHalfFloatReplacement.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.AddHalfNode;
5959
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.MultHalfNode;
6060
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.ReadHalfFloatNode;
61+
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.SPIRVConvertFloatToHalf;
6162
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.SPIRVConvertHalfToFloat;
6263
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.SubHalfNode;
6364
import uk.ac.manchester.tornado.drivers.spirv.graal.nodes.WriteHalfFloatNode;
@@ -102,7 +103,7 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
102103

103104
// replace reads with halfFloat reads
104105
for (JavaReadNode javaRead : graph.getNodes().filter(JavaReadNode.class)) {
105-
if (javaRead.successors().first() instanceof NewInstanceNode) {
106+
if (javaRead.successors().first() instanceof NewInstanceNode && javaRead.getReadKind() == JavaKind.Short) {
106107
NewInstanceNode newInstanceNode = (NewInstanceNode) javaRead.successors().first();
107108
if (newInstanceNode.instanceClass().getAnnotation(HalfType.class) != null) {
108109
if (newInstanceNode.successors().first() instanceof NewHalfFloatInstance) {
@@ -127,6 +128,14 @@ protected void run(StructuredGraph graph, TornadoHighTierContext context) {
127128
newInstanceNode.replaceAtUsages(valueInput);
128129
deleteFixed(newInstanceNode);
129130
deleteFixed(newHalfFloatInstance);
131+
} else if (newInstanceNode.successors().first() instanceof JavaReadNode readValue && readValue.getReadKind() == JavaKind.Float) {
132+
SPIRVConvertFloatToHalf convertFloatToHalf = new SPIRVConvertFloatToHalf(readValue);
133+
graph.addWithoutUnique(convertFloatToHalf);
134+
newInstanceNode.replaceAtUsages(convertFloatToHalf);
135+
for (NewHalfFloatInstance newHalfFloatInstance : readValue.usages().filter(NewHalfFloatInstance.class)) {
136+
deleteFixed(newHalfFloatInstance);
137+
}
138+
deleteFixed(newInstanceNode);
130139
}
131140
}
132141
}
@@ -265,6 +274,10 @@ private static ValueNode getHalfFloatValue(ValueNode halfFloatValue, StructuredG
265274
HalfFloatConstantNode halfFloatConstantNode = new HalfFloatConstantNode(floatValue);
266275
graph.addWithoutUnique(halfFloatConstantNode);
267276
return halfFloatConstantNode;
277+
} else if (halfFloatValue instanceof JavaReadNode javaReadNode && javaReadNode.getReadKind() == JavaKind.Float) {
278+
SPIRVConvertFloatToHalf convertFloatToHalf = new SPIRVConvertFloatToHalf(javaReadNode);
279+
graph.addWithoutUnique(convertFloatToHalf);
280+
return convertFloatToHalf;
268281
} else {
269282
return halfFloatValue;
270283
}

0 commit comments

Comments
 (0)