diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index a86eb3f5..cf9fd468 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -761,8 +761,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev interToken := simulator.getInterTokenLatency() - Expect(interToken).To(BeNumerically(">=", float32(interTokenLatency)*0.3)) - Expect(interToken).To(BeNumerically("<=", float32(interTokenLatency)*1.7)) + Expect(interToken).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3))) + Expect(interToken).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7))) }, func(interTokenLatency int, stddev int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d", interTokenLatency, stddev) @@ -778,8 +778,8 @@ var _ = Describe("Simulator", func() { simulator.config.InterTokenLatency = interTokenLatency simulator.config.InterTokenLatencyStdDev = stddev latency := simulator.getTotalInterTokenLatency(numberOfTokens) - Expect(latency).To(BeNumerically(">=", float32(interTokenLatency)*0.3*float32(numberOfTokens))) - Expect(latency).To(BeNumerically("<=", float32(interTokenLatency)*1.7*float32(numberOfTokens))) + Expect(latency).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3*float32(numberOfTokens)))) + Expect(latency).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7*float32(numberOfTokens)))) }, func(interTokenLatency int, stddev int, numberOfTokens int) string { return fmt.Sprintf("interTokenLatency: %d stddev: %d, numberOfTokens: %d", interTokenLatency, @@ -800,11 +800,11 @@ var _ = Describe("Simulator", func() { simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill) if doREmotePrefill { - Expect(timeToFirst).To(BeNumerically(">=", float32(kvCacheLatency)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(kvCacheLatency)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7))) } else { - Expect(timeToFirst).To(BeNumerically(">=", float32(timeToFirstToken)*0.3)) - Expect(timeToFirst).To(BeNumerically("<=", float32(timeToFirstToken)*1.7)) + Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7))) } }, func(timeToFirstToken int, timeToFirstTokenStdDev int,