Skip to content

Commit a2cf7ed

Browse files
committed
tests and refactoring
1 parent a16654c commit a2cf7ed

File tree

7 files changed

+277
-246
lines changed

7 files changed

+277
-246
lines changed

internal/weights/weights.go

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ import (
88

99
const NOT_SELECTED_WEIGHT = -2
1010

11-
// WEIGHT_SATURATION_LIMIT is set to 2^30.
12-
// Weights for periods and selected variables increase exponentially
13-
// with the number of periods and selections.
14-
const WEIGHT_SATURATION_LIMIT = 1073741824
11+
// WEIGHTS_SATURATION_LIMIT is set to 2^32.
12+
const WEIGHTS_SATURATION_LIMIT = 4294967296
1513

1614
type Weights map[string]int
1715

@@ -44,8 +42,9 @@ func (w Weights) absMaxWeight() int {
4442
return maxWeight
4543
}
4644

47-
func (w Weights) ContainsTooLargeWeight() bool {
48-
tooLarge := w.absMaxWeight() > WEIGHT_SATURATION_LIMIT
45+
func (w Weights) WeightsToLarge() bool {
46+
sum := w.sum()
47+
tooLarge := abs(sum) > WEIGHTS_SATURATION_LIMIT
4948
return tooLarge
5049
}
5150

@@ -68,13 +67,13 @@ func Calculate(
6867
notSelectedSum,
6968
preferredSum,
7069
)
71-
absMaxPeriodWeight := periodWeights.absMaxWeight()
70+
maxPeriodWeight := periodWeights.absMaxWeight()
7271

7372
selectedWeights := calculateSelectedWeights(
7473
selections,
7574
notSelectedSum,
7675
preferredSum,
77-
absMaxPeriodWeight,
76+
maxPeriodWeight,
7877
)
7978

8079
weights := notSelectedWeights.
@@ -114,15 +113,21 @@ func calculatePreferredWeights(
114113

115114
func calculatePeriodWeights(
116115
periodIDs []string,
117-
notSelectedSum,
118-
preferredSum int,
116+
notSelectedSum int,
117+
preferredWeightsSum int,
119118
) Weights {
120119
periodWeights := make(Weights)
121120

122-
threshold := -absSum(notSelectedSum, preferredSum)
121+
threshold := -absSum(notSelectedSum, preferredWeightsSum)
123122

124123
periodWeightSum := threshold
125-
for _, periodID := range periodIDs {
124+
for i, periodID := range periodIDs {
125+
if i == 0 {
126+
periodWeights[periodID] = 0
127+
128+
continue
129+
}
130+
126131
weight := periodWeightSum - 1
127132
periodWeights[periodID] = weight
128133
periodWeightSum += weight
@@ -135,11 +140,11 @@ func calculateSelectedWeights(
135140
selections Selections,
136141
notSelectedSum,
137142
preferredWeightsSum int,
138-
absMaxPeriodWeight int,
143+
maxPeriodWeight int,
139144
) Weights {
140145
selectedWeights := make(Weights)
141146

142-
threshold := absSum(notSelectedSum, preferredWeightsSum, absMaxPeriodWeight)
147+
threshold := absSum(notSelectedSum, preferredWeightsSum, maxPeriodWeight)
143148

144149
selectionWeightSum := threshold
145150
for _, selection := range selections {

internal/weights/weights_test.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,11 @@ func Test_absSum(t *testing.T) {
280280
terms []int
281281
expected int
282282
}{
283-
{terms: []int{1, -2, 3}, expected: 6},
284-
{terms: []int{-1, -2, -3}, expected: 6},
283+
{terms: []int{}, expected: 0},
285284
{terms: []int{0, 0, 0}, expected: 0},
285+
{terms: []int{-1, -2, -3}, expected: 6},
286+
{terms: []int{1, 2, 3}, expected: 6},
287+
{terms: []int{-1, -2, -3}, expected: 6},
286288
}
287289

288290
for _, theory := range theories {
@@ -308,26 +310,26 @@ func Test_calculatePeriodWeights(t *testing.T) {
308310
name: "given periodIDs and zero notSelectedSum and preferredSum",
309311
periodIDs: []string{"a", "b"},
310312
want: Weights{
311-
"a": -1,
312-
"b": -2,
313+
"a": 0,
314+
"b": -1,
313315
},
314316
},
315317
{
316318
name: "given periodIDs and non-zero positive notSelectedSum",
317319
periodIDs: []string{"a", "b"},
318320
notSelectedSum: 1,
319321
want: Weights{
320-
"a": -2,
321-
"b": -4,
322+
"a": 0,
323+
"b": -2,
322324
},
323325
},
324326
{
325327
name: "given periodIDs and non-zero negative notSelectedSum",
326328
periodIDs: []string{"a", "b"},
327329
notSelectedSum: -1,
328330
want: Weights{
329-
"a": -2,
330-
"b": -4,
331+
"a": 0,
332+
"b": -2,
331333
},
332334
},
333335
{
@@ -336,8 +338,8 @@ func Test_calculatePeriodWeights(t *testing.T) {
336338
notSelectedSum: -1,
337339
preferredSum: -2,
338340
want: Weights{
339-
"a": -4,
340-
"b": -8,
341+
"a": 0,
342+
"b": -4,
341343
},
342344
},
343345
}

puan/solution_creator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func (c *SolutionCreator) findDependentSolution(
7070
return SolutionEnvelope{}, err
7171
}
7272

73-
tooLarge := query.weights.ContainsTooLargeWeight()
73+
tooLarge := query.weights.WeightsToLarge()
7474

7575
solution, err := c.Solve(query)
7676
if err != nil {

tests/integration_tests/solve/period_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"testing"
55
"time"
66

7+
"github.com/go-faker/faker/v4/pkg/options"
78
"github.com/stretchr/testify/assert"
89

10+
"github.com/ourstudio-se/puan-sdk-go/internal/fake"
911
"github.com/ourstudio-se/puan-sdk-go/puan"
1012
)
1113

@@ -454,3 +456,121 @@ func Test_givenTimeEnabledWithoutTimeboundConstraints_andLateFromSpecified_shoul
454456
_, err := solutionCreator.Create(nil, ruleset, &afterEnd)
455457
assert.Error(t, err)
456458
}
459+
460+
// Global XOR rule for item1 and item2,
461+
// item2 has many consequences in the first period.
462+
// The solver should choose the first period when item2 is selected.
463+
func Test_givenXORWithManyConsequencesInFirstPeriod_shouldChooseFirstPeriod(t *testing.T) {
464+
creator := puan.NewRulesetCreator()
465+
startTime := time.Now()
466+
endTime := startTime.Add(1 * time.Hour)
467+
_ = creator.EnableTime(startTime, endTime)
468+
469+
item1 := fake.New[string]()
470+
item2 := fake.New[string]()
471+
_ = creator.AddPrimitives(item1, item2)
472+
xorID, _ := creator.SetXor(item1, item2)
473+
_ = creator.Assume(xorID)
474+
475+
item2Consequences := fake.New[[]string](
476+
func(oo *options.Options) {
477+
oo.RandomMinSliceSize = 50
478+
oo.RandomMaxSliceSize = 50
479+
},
480+
)
481+
_ = creator.AddPrimitives(item2Consequences...)
482+
483+
andID, _ := creator.SetAnd(item2Consequences...)
484+
item2Implies, _ := creator.SetImply(item2, andID)
485+
486+
endOfFirstPeriod := startTime.Add(30 * time.Minute)
487+
_ = creator.AssumeInPeriod(item2Implies, startTime, endOfFirstPeriod)
488+
489+
ruleset, _ := creator.Create()
490+
491+
envelope, _ := solutionCreator.Create(
492+
puan.Selections{
493+
puan.NewSelectionBuilder(item2).Build(),
494+
},
495+
ruleset,
496+
&startTime,
497+
)
498+
499+
solution := envelope.Solution()
500+
501+
asserter := newSolutionAsserter(solution)
502+
asserter.assertActive(t, item2)
503+
asserter.assertActive(t, "period_0")
504+
asserter.assertActive(t, item2Consequences...)
505+
asserter.assertInactive(t, item1)
506+
asserter.assertInactive(t, "period_1")
507+
}
508+
509+
// Global XOR rules for item1 and many other items,
510+
// all other items are preferred in the first period.
511+
// The solver should choose the first period when item1 is selected.
512+
func Test_givenXORWithManyPreferredInFirstPeriod_shouldChooseFirstPeriod(t *testing.T) {
513+
creator := puan.NewRulesetCreator()
514+
startTime := time.Now()
515+
endTime := startTime.Add(1 * time.Hour)
516+
_ = creator.EnableTime(startTime, endTime)
517+
518+
item1 := fake.New[string]()
519+
_ = creator.AddPrimitives(item1)
520+
521+
otherItems := fake.New[[]string](
522+
func(oo *options.Options) {
523+
oo.RandomMinSliceSize = 50
524+
oo.RandomMaxSliceSize = 50
525+
},
526+
)
527+
_ = creator.AddPrimitives(otherItems...)
528+
529+
endOfFirstPeriod := startTime.Add(30 * time.Minute)
530+
for _, otherItem := range otherItems {
531+
xorID, _ := creator.SetXor(item1, otherItem)
532+
_ = creator.Assume(xorID)
533+
preferredOtherItem, _ := creator.SetImply(xorID, otherItem)
534+
_ = creator.PreferInPeriod(preferredOtherItem, startTime, endOfFirstPeriod)
535+
}
536+
537+
ruleset, _ := creator.Create()
538+
539+
envelope, _ := solutionCreator.Create(
540+
puan.Selections{
541+
puan.NewSelectionBuilder(item1).Build(),
542+
},
543+
ruleset,
544+
&startTime,
545+
)
546+
547+
solution := envelope.Solution()
548+
549+
asserter := newSolutionAsserter(solution)
550+
asserter.assertActive(t, item1)
551+
asserter.assertActive(t, "period_0")
552+
asserter.assertInactive(t, otherItems...)
553+
asserter.assertInactive(t, "period_1")
554+
}
555+
556+
type solutionAsserter struct {
557+
puan.Solution
558+
}
559+
560+
func newSolutionAsserter(solution puan.Solution) solutionAsserter {
561+
return solutionAsserter{solution}
562+
}
563+
564+
func (s solutionAsserter) assertActive(t *testing.T, variables ...string) {
565+
solution := s.Extract(variables...)
566+
for variable, value := range solution {
567+
assert.Equal(t, 1, value, "expected %s to be active", variable)
568+
}
569+
}
570+
571+
func (s solutionAsserter) assertInactive(t *testing.T, variables ...string) {
572+
solution := s.Extract(variables...)
573+
for variable, value := range solution {
574+
assert.Equal(t, 0, value, "expected %s to be inactive", variable)
575+
}
576+
}

0 commit comments

Comments
 (0)