-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
325 lines (271 loc) · 10.8 KB
/
server.py
File metadata and controls
325 lines (271 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
from flask import (
Flask,
request,
render_template,
jsonify,
)
from enum import Enum
import torch
import os
import sys
import re
import tempfile
import json
from typing import List, Dict, Any, Optional
# Add project root to path so imports work when running this file directly
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if ROOT not in sys.path:
sys.path.insert(0, ROOT)
from src.model.transformer import AlphaMonGPT
from src.data.vocab_manager import VocabManager
from src.data.parser import LogTokenizer
from src.monte_carlo_search import MonteCarloSearcher, TURN, TEAM_PREVIEW, FAINT
app = Flask(__name__)

# Configuration
# NOTE(review): D_MODEL / N_HEAD / N_LAYERS / CONTEXT_LEN must match the
# architecture the checkpoint was trained with — a mismatch will fail in
# load_state_dict() inside load_model(). Confirm against training config.
CHECKPOINT_PATH = "checkpoints/alphamon_finetuned_best.pt"  # relative to CWD, loaded by load_model()
CONTEXT_LEN = 2048  # max sequence length (passed to the model as max_len)
D_MODEL = 512       # embedding width
N_HEAD = 8          # attention heads per layer
N_LAYERS = 6        # transformer layers
DROPOUT = 0.1       # has no effect at inference (model.eval()), kept for checkpoint parity
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
# --- ENUM DEFINITION ---
class Status(Enum):
    """Lifecycle of the current prediction request.

    The .value strings double as CSS class names / status strings sent
    to clients (see index() and poll_status()).
    """

    WAITING = "waiting"        # no battle data received yet
    GENERATING = "generating"  # a search is in progress
    READY = "ready"            # predictions are available for display
# Global variables for model (all populated by load_model())
model: Optional[AlphaMonGPT] = None
vm: Optional[VocabManager] = None
searcher: Optional[MonteCarloSearcher] = None
# Global variables for current state (mutated by the index() POST handler)
current_status: Status = Status.WAITING
current_best_moves: List[Dict[str, Any]] = []  # results of the most recent search
current_tokenized: str = ""                    # display string: tail of the tokenized log
current_player: int = 1                        # side the search is run for (1 or 2)
current_num_pokemons: int = 2                  # team size from the last request
global_objective: str = TURN                   # last requested objective (TURN/TEAM_PREVIEW/FAINT)
# History of past generations (newest first, capped at 10 in index())
current_history: List[Dict[str, Any]] = []
# simple change-tracker for clients: bumped on every state change,
# polled via /_status so clients know when to refresh
last_update: float = 0
def color_action(action_str: str, player: int) -> str:
    """Decorate an action string with HTML for display.

    Tokens like ``[P1 ...]`` belonging to *player* are wrapped in a
    ``player-pokemon`` span, the opponent's in ``opponent-pokemon``,
    and every ``[CMD_MOVE]`` marker becomes a ``<br/>`` line break.

    Args:
        action_str: raw action text containing bracketed side tokens.
        player: 1 or 2 — which side counts as "the player".

    Returns:
        The HTML-annotated string.
    """
    player_tag, opponent_tag = ("P1", "P2") if player == 1 else ("P2", "P1")

    def wrap(text: str, cls: str) -> str:
        # Simple HTML coloring helper
        return f'<span class="{cls}">{text}</span>'

    # Highlight the player's tokens first, then the opponent's.
    highlight_rules = (
        (player_tag, "player-pokemon"),
        (opponent_tag, "opponent-pokemon"),
    )
    for tag, css_class in highlight_rules:
        # Matches bracketed tokens that start with the side tag, e.g. [P1 Pikachu]
        action_str = re.sub(
            rf"\[{tag}[^\]]*\]",
            lambda m, cls=css_class: wrap(m.group(0), cls),
            action_str,
        )

    # Each [CMD_MOVE] becomes a newline in the final HTML for readability
    return action_str.replace("[CMD_MOVE]", "<br/>")
def load_model():
    """Load vocabulary and checkpoint, then build the Monte-Carlo searcher.

    Populates the module globals ``model``, ``vm`` and ``searcher`` and
    bumps ``last_update`` so polling clients notice the change.

    Raises:
        FileNotFoundError: if CHECKPOINT_PATH does not exist.
    """
    global model, vm, searcher, last_update

    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

    vm = VocabManager()
    vm.load_vocab()

    model = AlphaMonGPT(
        vocab_size=len(vm),
        d_model=D_MODEL,
        n_head=N_HEAD,
        n_layers=N_LAYERS,
        max_len=CONTEXT_LEN,
        dropout=DROPOUT,
    )

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    raw_weights = checkpoint.get("model_state_dict", checkpoint)

    # torch.compile() wraps the model and prefixes every key with
    # "_orig_mod."; strip that so the vanilla module can load the weights.
    prefix = "_orig_mod."
    cleaned_weights = {}
    for key, tensor in raw_weights.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned_weights[key] = tensor

    model.load_state_dict(cleaned_weights)
    model.to(DEVICE)
    model.eval()

    searcher = MonteCarloSearcher(model, vm, device=DEVICE)
    print("Model loaded successfully.")
    last_update += 1
@app.route("/", methods=["GET", "POST"])
def index():
global current_status, current_best_moves, current_tokenized, current_player, current_num_pokemons, last_update, global_objective, current_history
# Default values for GET
status_text = "Waiting for battle data..."
status_class = Status.WAITING.value
# persist last selected objective across requests so GET (page reload) shows correct heading
objective = global_objective
if request.method == "POST":
# 1. Archive current results to history if they exist
if current_best_moves:
history_item = {
"moves": current_best_moves,
"tokenized": current_tokenized,
"objective": global_objective,
"timestamp": last_update,
}
current_history.insert(0, history_item)
# Keep only last 10
if len(current_history) > 10:
current_history.pop()
# 2. Reset for new generation
current_status = Status.GENERATING
last_update += 1
current_best_moves = [] # Clear for now
log_text = request.form["log"]
current_player = int(request.form["player"])
current_num_pokemons = int(request.form.get("num_pokemons", 2))
objective = request.form.get("objective", TURN)
# Parse team data for logit bias if provided
team_data = None
if "team_data" in request.form:
try:
team_data = json.loads(request.form["team_data"])
print(
f"Received team data for {len(team_data)} Pokémon to use as logit bias."
)
except json.JSONDecodeError:
print("Failed to parse team_data JSON.")
# remember the chosen objective so subsequent GET (page reloads) render correctly
global_objective = objective
# Tokenize the log
if vm is None:
return jsonify({"error": "Model not loaded"}), 500
tokenizer = LogTokenizer(vm)
# Save log to temp file for tokenization
with tempfile.NamedTemporaryFile(mode="w", suffix=".log", delete=False) as tmp:
tmp.write(log_text)
tmp_path = tmp.name
try:
tokens = tokenizer.tokenize_replay(tmp_path)
# Remove EOS if present
if tokens and tokens[-1] == vm.vocab[vm.EOS_TOKEN]:
tokens = tokens[:-1]
finally:
os.unlink(tmp_path)
current_tokenized = (
" ".join([vm.decode(t) for t in tokens[-50:]]) + f" (len={len(tokens)})"
) # Last 50 tokens for display
# Dynamic rollouts: tuned for RTX 3090 (24 GB)
base_rollouts = 512
context_penalty = len(tokens) // 400 # gentle reduction at very long contexts
# Setup parameters based on objective
if objective == TEAM_PREVIEW:
rollout_len = 128 # exhaustive eval, rollout_len unused
num_rollouts = max(64, base_rollouts)
elif objective == FAINT:
rollout_len = 128 # exhaustive eval, rollout_len unused
num_rollouts = max(32, base_rollouts - context_penalty * 2)
else: # TURN
rollout_len = 128
num_rollouts = max(64, base_rollouts - context_penalty * 2)
# Build logit bias tensor if team data is provided
logit_bias = None
if team_data and vm is not None:
# Initialize bias tensor with zeros
logit_bias = torch.zeros((1, len(vm)), device=DEVICE)
bias_value = 2.0 # Gentle nudge toward known team moves
for species, data in team_data.items():
# Boost the species token
mon_token = vm.formatted_token("MON", data["species"])
if mon_token in vm.vocab:
logit_bias[0, vm.vocab[mon_token]] += bias_value
# Boost the moves
for move in data.get("moves", []):
move_token = vm.formatted_token("MOVE", move)
if move_token in vm.vocab:
logit_bias[0, vm.vocab[move_token]] += bias_value
# Boost the item
if data.get("item"):
item_token = vm.formatted_token("ITEM", data["item"])
if item_token in vm.vocab:
logit_bias[0, vm.vocab[item_token]] += bias_value
# Boost the tera type
if data.get("tera_type"):
tera_token = vm.formatted_token("TYPE", data["tera_type"])
if tera_token in vm.vocab:
logit_bias[0, vm.vocab[tera_token]] += bias_value
print("Applied logit bias based on team data.")
# Run MCS
print(f"Starting Search for {objective}...")
if searcher is None:
return jsonify({"error": "Model not loaded"}), 500
current_best_moves = searcher.search_move(
tokens,
num_rollouts=num_rollouts,
rollout_len=rollout_len, # Longer for multiple actions
temperature=0.7,
as_player=current_player,
objective=objective,
logit_bias=logit_bias,
)
# Free GPU memory immediately — long battles accumulate context fast
if logit_bias is not None:
del logit_bias
logit_bias = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("Top moves found:")
for m in current_best_moves[:5]:
print(
f" Action: {m.get('action', '?')}, "
f"Avg: {m['avg_score']:.4f}, LCB: {m.get('lcb_score', 0):.4f}, "
f"Count: {m['count']}"
)
# The MCS now returns pre-formatted HTML in 'action'
for m in current_best_moves:
m["colored_action"] = m.get("action", "Unknown Action")
current_status = Status.READY
last_update += 1
# If the request asks for JSON (like the auto_server does), return JSON
if request.headers.get("Accept") == "application/json":
return jsonify(
{
"status": "success",
"objective": objective,
"best_moves": current_best_moves,
}
)
# Determine status text and class for display (updated state)
if current_status == Status.WAITING:
status_text = "Waiting for battle data..."
status_class = Status.WAITING.value
elif current_status == Status.GENERATING:
status_text = "Generating predictions..."
status_class = Status.GENERATING.value
elif current_status == Status.READY:
status_text = "Predictions ready!"
status_class = Status.READY.value
return render_template(
"index.html",
status_text=status_text,
status_class=status_class,
current_best_moves=current_best_moves,
current_history=current_history,
current_tokenized=current_tokenized,
objective=objective,
last_update=last_update,
)
@app.route("/_status", methods=["GET"])
def poll_status():
# lightweight endpoint for clients to check if content changed
# We use .value to return the string ('ready') instead of the Enum object
return jsonify({"last_update": last_update, "status": current_status.value})
if __name__ == "__main__":
load_model()
app.run(debug=True, host="0.0.0.0", port=5000)