Skip to content

Commit 354060e

Browse files
committed
added short term plans, fixed tests, made activities more realistic
1 parent 671defd commit 354060e

File tree

14 files changed

+162
-36
lines changed

14 files changed

+162
-36
lines changed

smallville/src/main/java/io/github/nickm980/smallville/config/PromptsConfig.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public class PromptsConfig {
1111
private String createPastAndPresent;
1212
private String createObjectUpdates;
1313
private String pickLocation;
14+
private String createShortTermPlans;
1415

1516
public PromptsConfig() {
1617
}
@@ -87,4 +88,12 @@ public void setPickLocation(String pickLocation) {
8788
this.pickLocation= pickLocation;
8889
}
8990

91+
public String getCreateShortTermPlans() {
92+
return createShortTermPlans;
93+
}
94+
95+
public void setCreateShortTermPlans(String createShortTermPlans) {
96+
this.createShortTermPlans = createShortTermPlans;
97+
}
98+
9099
}

smallville/src/main/java/io/github/nickm980/smallville/llm/update/ChatService.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,23 @@ public String ask(Agent agent, String question) {
113113
return chat.sendChat(prompt, .9);
114114
}
115115

116-
public List<Plan> getPlans(Agent agent, TimePhrase phrase) {
116+
public List<Plan> getPlans(Agent agent) {
117117
Prompt prompt = new PromptBuilder()
118118
.withLocations(world.getLocations())
119119
.withAgent(agent)
120-
.createFuturePlansPrompt(phrase)
120+
.createFuturePlansPrompt()
121+
.build();
122+
123+
String response = chat.sendChat(prompt, .7);
124+
125+
return parsePlans(response);
126+
}
127+
128+
public List<Plan> getShortTermPlans(Agent agent) {
129+
Prompt prompt = new PromptBuilder()
130+
.withLocations(world.getLocations())
131+
.withAgent(agent)
132+
.createShortTermPlansPrompt()
121133
.build();
122134

123135
String response = chat.sendChat(prompt, .7);
@@ -150,9 +162,7 @@ public CurrentPlan getCurrentPlan(Agent agent) {
150162
result.setLastActivity(json.get("last_activity").asText());
151163
result.setLocation(json.get("location").asText());
152164

153-
LOG
154-
.info(agent.getFullName() + ": " + result.getCurrentActivity() + " emoji: " + result.getEmoji()
155-
+ " location: " + agent.getLocation());
165+
LOG.info("[Activity]" + result.getCurrentActivity() + " location: " + agent.getLocation().getName());
156166

157167
return result;
158168
}

smallville/src/main/java/io/github/nickm980/smallville/llm/update/UpdateCurrentActivity.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class UpdateCurrentActivity extends AgentUpdate {
1010

1111
@Override
1212
public boolean update(ChatService service, World world, Agent agent) {
13-
LOG.info("[Updater / Activity] Updating current activity and emoji");
13+
LOG.info("[Activity] Updating current activity and emoji");
1414

1515
CurrentPlan plan = service.getCurrentPlan(agent);
1616
SimulatedLocation location = world.getLocation(plan.getLocation()).orElse(agent.getLocation());

smallville/src/main/java/io/github/nickm980/smallville/llm/update/UpdateFuturePlans.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,27 @@ public class UpdateFuturePlans extends AgentUpdate {
1212
@Override
1313
public boolean update(ChatService converter, World world, Agent agent) {
1414
LOG.info("[Plans] Updating future plans");
15+
agent.getMemoryStream().prunePlans();
1516

1617
if (agent.getPlans().isEmpty() || agent.getPlans().size() < 5) {
17-
List<Plan> future = converter.getPlans(agent, TimePhrase.DAY);
18-
agent.setPlans(future);
18+
List<Plan> plans = converter.getPlans(agent);
19+
agent.setPlans(plans);
1920
}
2021

21-
// TODO: iterate for finer grain plans that are closer to the present
22+
if (agent.getMemoryStream().getShortTermPlans().size() < 5) {
23+
List<Plan> plans = converter.getShortTermPlans(agent);
2224

23-
for (Plan plan : agent.getPlans()) {
24-
LOG.info("[Plans] " + agent.getFullName() + ": " + plan.asNaturalLanguage());
25+
for (Plan plan : plans) {
26+
plan.convertToShortTermMemory(true);
27+
}
28+
29+
agent.setShortTermPlans(plans);
30+
}
31+
32+
for (Plan plan : agent.getMemoryStream().sortByTime(agent.getPlans()).stream().map(m -> (Plan) m).toList()) {
33+
LOG.info("[Plans] " + plan.asNaturalLanguage() + " short term: " + plan.isShortTerm());
2534
}
2635

2736
return next(converter, world, agent);
2837
}
29-
3038
}

smallville/src/main/java/io/github/nickm980/smallville/llm/update/UpdateLocations.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ public class UpdateLocations extends AgentUpdate {
88

99
@Override
1010
public boolean update(ChatService converter, World world, Agent agent) {
11+
LOG.info("[Locations] Updating location states");
12+
1113
ObjectChangeResponse[] objects = converter.getObjectsChangedBy(agent);
1214

1315
if (objects.length > 0) {
@@ -16,6 +18,7 @@ public boolean update(ChatService converter, World world, Agent agent) {
1618
}
1719
}
1820

21+
LOG.info("[Locations] Location states updated");
1922
return next(converter, world, agent);
2023
}
2124
}

smallville/src/main/java/io/github/nickm980/smallville/models/Agent.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,8 @@ public void setPlans(List<Plan> plans) {
7171
public List<Plan> getPlans() {
7272
return memories.getPlans();
7373
}
74+
75+
public void setShortTermPlans(List<Plan> plans) {
76+
this.getMemoryStream().setShortTermPlans(plans);
77+
}
7478
}

smallville/src/main/java/io/github/nickm980/smallville/models/memory/MemoryStream.java

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.github.nickm980.smallville.models.memory;
22

33
import java.time.LocalDateTime;
4+
import java.util.ArrayList;
5+
import java.util.Comparator;
46
import java.util.HashSet;
57
import java.util.List;
68
import java.util.Set;
@@ -10,10 +12,10 @@
1012
* Includes plans, observations, and characteristics
1113
*/
1214
public class MemoryStream {
13-
private Set<Memory> memories;
15+
private List<Memory> memories;
1416

1517
public MemoryStream() {
16-
this.memories = new HashSet<Memory>();
18+
this.memories = new ArrayList<Memory>();
1719
}
1820

1921
/**
@@ -38,7 +40,7 @@ public void remember(String memory) {
3840
this.memories.add(new Observation(memory));
3941
}
4042

41-
public Set<Memory> getMemories() {
43+
public List<Memory> getMemories() {
4244
return memories;
4345
}
4446

@@ -83,13 +85,33 @@ public void addCharacteristics(List<Characteristic> characteristics) {
8385
}
8486

8587
public void prunePlans() {
86-
for (Memory memory : memories) {
88+
memories.removeIf((memory) -> {
8789
if (memory instanceof Plan) {
8890
Plan plan = (Plan) memory;
8991
if (plan.getTime() != null && plan.getTime().compareTo(LocalDateTime.now()) < 0) {
90-
memories.remove(plan);
92+
return true;
9193
}
9294
}
93-
}
95+
return false;
96+
});
97+
}
98+
99+
public void setShortTermPlans(List<Plan> plans) {
100+
List<Plan> removed = getShortTermPlans();
101+
memories.removeAll(removed);
102+
memories.addAll(plans);
103+
}
104+
105+
public List<Plan> getShortTermPlans() {
106+
return getPlans().stream().filter(plan -> plan.isShortTerm()).toList();
107+
}
108+
109+
public List<? extends TemporalMemory> sortByTime(List<? extends TemporalMemory> mems) {
110+
return mems.stream().sorted(new Comparator<TemporalMemory>() {
111+
@Override
112+
public int compare(TemporalMemory o1, TemporalMemory o2) {
113+
return o1.getTime().compareTo(o2.getTime());
114+
}
115+
}).toList();
94116
}
95117
}

smallville/src/main/java/io/github/nickm980/smallville/models/memory/Plan.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
package io.github.nickm980.smallville.models.memory;
22

3-
import java.time.Duration;
43
import java.time.LocalDateTime;
54
import java.time.temporal.ChronoUnit;
65

76
import io.github.nickm980.smallville.math.SmallvilleMath;
87
import io.github.nickm980.smallville.models.AccessTime;
9-
import io.github.nickm980.smallville.models.SimulatedLocation;
108
import io.github.nickm980.smallville.models.NaturalLanguageConvertible;
119

1210
/**
@@ -18,10 +16,20 @@
1816
public class Plan extends Memory implements TemporalMemory, NaturalLanguageConvertible {
1917

2018
private final LocalDateTime time;
19+
private boolean isShortTerm;
2120

2221
public Plan(String description, LocalDateTime time) {
22+
this(description, time, false);
23+
}
24+
25+
public Plan(String description, LocalDateTime time, boolean isShortTerm) {
2326
super(description);
2427
this.time = time;
28+
this.isShortTerm = isShortTerm;
29+
}
30+
31+
public boolean isShortTerm() {
32+
return isShortTerm;
2533
}
2634

2735
public LocalDateTime getTime() {
@@ -42,4 +50,8 @@ public LocalDateTime getTime() {
4250
public String asNaturalLanguage() {
4351
return getDescription();
4452
}
53+
54+
public void convertToShortTermMemory(boolean b) {
55+
this.isShortTerm = b;
56+
}
4557
}

smallville/src/main/java/io/github/nickm980/smallville/prompts/AtomicPromptBuilder.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import java.time.LocalDateTime;
44
import java.time.format.DateTimeFormatter;
55
import java.util.Collection;
6+
import java.util.Comparator;
67
import java.util.List;
78
import java.util.stream.Collectors;
89

@@ -13,6 +14,7 @@
1314
import io.github.nickm980.smallville.models.SimulatedObject;
1415
import io.github.nickm980.smallville.models.memory.Characteristic;
1516
import io.github.nickm980.smallville.models.memory.Memory;
17+
import io.github.nickm980.smallville.models.memory.Plan;
1618

1719
/**
1820
* Creates the variable prompts and converts objects to natural language
@@ -73,7 +75,6 @@ public String getAgentSummaryDescription(Agent person) {
7375
Name: %name%
7476
Description: %description%
7577
Current Location: %location%
76-
%plans%
7778
7879
[Current Time]
7980
""";
@@ -125,4 +126,21 @@ public CharSequence getObjects(List<SimulatedObject> objects) {
125126

126127
return result;
127128
}
129+
130+
public CharSequence getLatestPlan(Agent agent) {
131+
String result = "";
132+
133+
Plan plan = agent.getPlans().stream().sorted(new Comparator<Plan>() {
134+
@Override
135+
public int compare(Plan o1, Plan o2) {
136+
return o1.getTime().compareTo(o2.getTime());
137+
}
138+
}).findFirst().orElse(null);
139+
140+
if (plan == null) {
141+
return result;
142+
}
143+
144+
return "The next plan for the day is: " + plan.asNaturalLanguage();
145+
}
128146
}

smallville/src/main/java/io/github/nickm980/smallville/prompts/IPromptBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,9 @@ public interface IPromptBuilder {
7070
/**
7171
* Adds the future plans of the agent, given a time frame, to the prompt.
7272
*
73-
* @param time the time frame for the future plans
7473
* @return the prompt builder instance
7574
*/
76-
PromptBuilder createFuturePlansPrompt(TimePhrase time);
75+
PromptBuilder createFuturePlansPrompt();
7776

7877
/**
7978
*

0 commit comments

Comments
 (0)