Skip to content

Commit 84c9d75

Browse files
Add DE support to /ea API.
1 parent cd51991 commit 84c9d75

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

modules/ea.go

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/json"
55
"evolve/util"
66
"fmt"
7+
"slices"
78
"strings"
89
)
910

@@ -19,7 +20,7 @@ type EA struct {
1920
Cxpb float64 `json:"cxpb"`
2021
Mutpb float64 `json:"mutpb"`
2122
Weights []float64 `json:"weights"`
22-
IndividualSize int `json:"individualSize"`
23+
IndividualSize int `json:"individualSize"` // Number of Dimensions.
2324
Indpb float64 `json:"indpb"`
2425
RandomRange []float64 `json:"randomRange"`
2526
CrossoverFunction string `json:"crossoverFunction"`
@@ -31,6 +32,10 @@ type EA struct {
3132
Mu int `json:"mu,omitempty"`
3233
Lambda int `json:"lambda_,omitempty"`
3334
HofSize int `json:"hofSize,omitempty"`
35+
36+
// Differential Evolution Params.
37+
CrossOverRate float64 `json:"crossOverRate,omitempty"`
38+
ScalingFactor float64 `json:"scalingFactor,omitempty"`
3439
}
3540

3641
func EAFromJSON(jsonData map[string]any) (*EA, error) {
@@ -62,12 +67,18 @@ func (ea *EA) imports() string {
6267
"import matplotlib.pyplot as plt",
6368
"from functools import reduce",
6469
"from scoop import futures",
70+
"from deap import benchmarks",
6571
}, "\n")
6672
}
6773

6874
// If the function is a built-in function, return the corresponding Python code.
6975
// Otherwise, return the function string as is.
7076
func (ea *EA) evalFunction() string {
77+
if slices.Contains([]string{"rand", "plane", "sphere", "cigar", "rosenbrock", "h1", "ackley", "bohachevsky", "griewank", "rastrigin", "rastrigin_scaled", "rastrigin_skew", "schaffer", "schwefel", "himmelblau"}, ea.EvaluationFunction) {
78+
ea.EvaluationFunction = "benchmarks." + ea.EvaluationFunction
79+
return ""
80+
}
81+
7182
switch ea.EvaluationFunction {
7283
case "evalOneMax":
7384
return "def evalOneMax(individual):\n return sum(individual),"
@@ -109,6 +120,8 @@ func (ea *EA) mutationFunction() string {
109120
switch ea.MutationFunction {
110121
case "mutFlipBit":
111122
return fmt.Sprintf("toolbox.register(\"mutate\", tools.%s, indpb=%f)\n", ea.MutationFunction, ea.Indpb)
123+
case "mutShuffleIndexes":
124+
return fmt.Sprintf("toolbox.register(\"mutate\", tools.%s, indpb=%f)\n", ea.MutationFunction, ea.Indpb)
112125
default:
113126
return ea.MutationFunction
114127
}
@@ -120,7 +133,7 @@ func (ea *EA) selectionFunction() string {
120133
case "selTournament":
121134
return fmt.Sprintf("toolbox.register(\"select\", tools.%s, tournsize=%d)\n", ea.SelectionFunction, ea.TournamentSize)
122135
default:
123-
return ea.SelectionFunction
136+
return fmt.Sprintf("toolbox.register(\"select\", tools.%s)\n", ea.SelectionFunction)
124137
}
125138
}
126139

@@ -171,6 +184,38 @@ func (ea *EA) plots() string {
171184
return plots
172185
}
173186

187+
func (ea *EA) differentialEvolution() string {
188+
return strings.Join([]string{
189+
fmt.Sprintf("\tCR = %f", ea.CrossOverRate),
190+
fmt.Sprintf("F = %f", ea.ScalingFactor),
191+
"\n",
192+
"logbook = tools.Logbook()",
193+
"logbook.header = 'gen', 'evals', 'std', 'min', 'avg', 'max'",
194+
"fitnesses = toolbox.map(toolbox.evaluate, pop)",
195+
"for ind, fit in zip(pop, fitnesses):",
196+
"\tind.fitness.values = fit",
197+
"record = stats.compile(pop)",
198+
"logbook.record(gen=0, evals=len(pop), **record)",
199+
"print(logbook.stream)",
200+
"for g in range(1, generations):",
201+
"\tfor k, agent in enumerate(pop):",
202+
"\t\ta,b,c = toolbox.select(pop, 3)",
203+
"\t\ty = toolbox.clone(agent)",
204+
"\t\tindex = random.randrange(N)",
205+
"\t\tfor i, value in enumerate(agent):",
206+
"\t\t\tif i == index or random.random() < CR:",
207+
"\t\t\t\ty[i] = a[i] + F*(b[i]-c[i])",
208+
"\t\ty.fitness.values = toolbox.evaluate(y)",
209+
"\t\tif y.fitness > agent.fitness:",
210+
"\t\t\tpop[k] = y",
211+
"\thof.update(pop)",
212+
"\trecord = stats.compile(pop)",
213+
"\tlogbook.record(gen=g, evals=len(pop), **record)",
214+
"\tprint(logbook.stream)",
215+
"\n",
216+
}, "\n\t")
217+
}
218+
174219
func (ea *EA) Code() (string, error) {
175220
if err := ea.validate(); err != nil {
176221
return "", err
@@ -209,7 +254,13 @@ func (ea *EA) Code() (string, error) {
209254
code += "\tstats.register(\"min\", numpy.min)\n"
210255
code += "\tstats.register(\"max\", numpy.max)\n"
211256
code += "\n"
212-
code += ea.callAlgo() + "\n"
257+
258+
if ea.Algorithm == "de" {
259+
code += ea.differentialEvolution()
260+
} else {
261+
code += ea.callAlgo() + "\n"
262+
}
263+
213264
code += "\tprint(f'Best individual is: {hof[0]}\\nwith fitness: {hof[0].fitness}')"
214265
code += "\n\n"
215266
code += ea.plots()

util/validate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
)
77

88
func ValidateAlgorithmName(algo string) error {
9-
if slices.Contains([]string{"eaSimple", "eaMuPlusLambda", "eaMuCommaLambda", "eaGenerateUpdate"}, algo) {
9+
if slices.Contains([]string{"eaSimple", "eaMuPlusLambda", "eaMuCommaLambda", "eaGenerateUpdate", "de"}, algo) {
1010
return nil
1111
}
1212
return fmt.Errorf("invalid algorithm name: %s", algo)

0 commit comments

Comments
 (0)