Skip to content

Commit 3a6d85d

Browse files
committed
wip
1 parent 45cbd5e commit 3a6d85d

File tree

4 files changed

+296
-0
lines changed

4 files changed

+296
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,5 @@ assets/
167167
generated/
168168
*.db
169169
*.wav
170+
artifact
171+
mlartifacts

chatbot_module.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
枝豆の妖精スタイルチャットボットのモジュール
3+
"""
4+
5+
import dspy
6+
7+
8+
class ConversationSignature(dspy.Signature):
9+
"""枝豆の妖精として対話する"""
10+
11+
query = dspy.InputField(desc="ユーザーからの質問や発言")
12+
history = dspy.InputField(desc="過去の対話履歴", format=list, default=[])
13+
response = dspy.OutputField(
14+
desc="枝豆の妖精としての応答。語尾に「のだ」「なのだ」を自然に使い、一人称は「ボク」。親しみやすく可愛らしい口調で、日本語として自然な文章"
15+
)
16+
17+
18+
class EdamameFairyBot(dspy.Module):
19+
"""枝豆の妖精スタイルのチャットボット"""
20+
21+
def __init__(self):
22+
super().__init__()
23+
self.respond = dspy.Predict(ConversationSignature)
24+
25+
def forward(self, query: str, history: list | None = None) -> dspy.Prediction:
26+
if history is None:
27+
history = []
28+
return self.respond(query=query, history=history)

chatbot_tuning.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
"""枝豆の妖精スタイル チャットボット チューニング"""
2+
3+
import os
4+
5+
import dspy
6+
import mlflow
7+
import mlflow.dspy as mlflow_dspy
8+
from datasets import load_dataset
9+
from dotenv import load_dotenv
10+
11+
from chatbot_module import EdamameFairyBot
12+
from template_langgraph.llms.azure_openais import Settings
13+
14+
settings = Settings()
15+
16+
load_dotenv()
17+
18+
# MLflowの設定
19+
MLFLOW_PORT = os.getenv("MLFLOW_PORT", "5000")
20+
MLFLOW_TRACKING_URI = f"http://localhost:{MLFLOW_PORT}"
21+
MLFLOW_EXPERIMENT_NAME = "DSPy-EdamameFairy-Optimization"
22+
MLFLOW_RUN_NAME = "miprov2_optimization"
23+
24+
# 最適化されたモジュールの保存先
25+
OPTIMIAZED_MODEL_PATH = "artifact/edamame_fairy_model.json"
26+
27+
28+
def create_style_metric(eval_lm):
29+
"""スタイル評価関数を作成"""
30+
31+
class StyleEvaluation(dspy.Signature):
32+
"""応答のスタイルを評価"""
33+
34+
response = dspy.InputField(desc="評価対象の応答")
35+
criteria = dspy.InputField(desc="評価基準")
36+
score = dspy.OutputField(desc="スコア(0-10)", format=int)
37+
explanation = dspy.OutputField(desc="評価理由")
38+
39+
evaluator = dspy.ChainOfThought(StyleEvaluation)
40+
41+
def llm_style_metric(_, prediction, __=None):
42+
"""枝豆の妖精スタイルを評価"""
43+
criteria = """
44+
以下の基準で0-10点で評価してください:
45+
1. 語尾に「のだ」「なのだ」を適切に使っているか(3点)
46+
- 過度な使用(のだのだ等)は減点
47+
- 自然な日本語として成立しているか
48+
- 「なのだよ」「なのだね」といった語尾は不自然のため減点
49+
2. 一人称を使う際は「ボク」を使っているか(2点)
50+
3. 親しみやすく可愛らしい口調か(3点)
51+
4. 日本語として自然で読みやすいか(2点)
52+
- 不自然な繰り返しがないか
53+
- 文法的に正しいか
54+
"""
55+
56+
# 評価用LMを使用して応答を評価
57+
with dspy.context(lm=eval_lm):
58+
eval_result = evaluator(response=prediction.response, criteria=criteria)
59+
60+
# スコアを0-1の範囲に正規化
61+
score = min(10, max(0, float(eval_result.score))) / 10.0
62+
return score
63+
64+
return llm_style_metric
65+
66+
67+
def optimize_with_miprov2(trainset, eval_lm, chat_lm):
68+
"""MIPROv2を使用してチャットボットを最適化"""
69+
70+
# データセットをtrain:val = 8:2 の割合で分割
71+
total_examples = len(trainset)
72+
train_size = int(total_examples * 0.8) # 全体の80%を学習用に
73+
74+
# DSPy Exampleのリストを分割
75+
train_data = trainset[:train_size] # インデックス0からtrain_sizeまで(学習用)
76+
evaluation_data = trainset[train_size:] # train_sizeから最後まで(評価用)
77+
78+
# 分割結果の確認と表示
79+
print("🌱 最適化開始")
80+
print(f" 総データ数: {total_examples}")
81+
print(f" 学習用データ: {len(train_data)} ({len(train_data) / total_examples:.1%})")
82+
print(f" 評価用データ: {len(evaluation_data)} ({len(evaluation_data) / total_examples:.1%})")
83+
84+
# 最適化対象のチャットボットモジュールを初期化
85+
chatbot = EdamameFairyBot()
86+
87+
# スタイル評価関数を作成(評価用LMを使用)
88+
llm_style_metric = create_style_metric(eval_lm)
89+
90+
# DSPyのグローバルLM設定(チャット推論用)
91+
dspy.configure(lm=chat_lm)
92+
93+
# MIPROv2オプティマイザの設定
94+
optimizer = dspy.MIPROv2(
95+
metric=llm_style_metric, # 評価関数
96+
prompt_model=eval_lm, # プロンプト最適化用のLM
97+
auto="light", # 最適化モード(light, medium, heavyから選択)
98+
max_bootstrapped_demos=2,
99+
max_labeled_demos=1,
100+
)
101+
102+
# MLflowの設定
103+
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) # MLflowサーバのURL
104+
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME) # MLflowの実験名
105+
106+
# MLflow DSPyの自動ログ設定
107+
mlflow_dspy.autolog(log_compiles=True, log_evals=True, log_traces_from_compile=True)
108+
109+
# MLflowで実行過程をトレース
110+
with mlflow.start_run(run_name=MLFLOW_RUN_NAME):
111+
# MIPROv2によるモジュール最適化の実行
112+
# train_dataを使用してプロンプトと例を自動調整
113+
optimized_chatbot = optimizer.compile(chatbot, trainset=train_data, minibatch_size=20)
114+
115+
# 評価データでモデルの性能を評価
116+
eval_score = 0
117+
for example in evaluation_data:
118+
# 最適化されたモデルで推論を実行
119+
prediction = optimized_chatbot(query=example.query, history=example.history)
120+
# スタイルスコアを計算
121+
eval_score += llm_style_metric(example, prediction)
122+
123+
# 平均評価スコアを計算
124+
avg_eval_score = eval_score / len(evaluation_data)
125+
126+
# MLflowにメトリクスを記録
127+
mlflow.log_metric("last_eval_score", avg_eval_score)
128+
129+
print(f"📊 評価スコア: {avg_eval_score:.3f}")
130+
131+
return optimized_chatbot
132+
133+
134+
def main():
135+
"""メイン実行関数"""
136+
137+
# 評価用LLMの設定
138+
eval_lm = dspy.LM(
139+
model=f"azure/{settings.azure_openai_model_chat}",
140+
api_key=settings.azure_openai_api_key,
141+
api_base=settings.azure_openai_endpoint,
142+
api_version=settings.azure_openai_api_version,
143+
temperature=0.0,
144+
max_tokens=1000,
145+
)
146+
147+
# チャット推論用LLMの設定
148+
chat_lm = dspy.LM(
149+
model=f"azure/{settings.azure_openai_model_chat}",
150+
api_key=settings.azure_openai_api_key,
151+
api_base=settings.azure_openai_endpoint,
152+
api_version=settings.azure_openai_api_version,
153+
temperature=0.0,
154+
max_tokens=1000,
155+
)
156+
157+
# 日本語データセットの読み込み(ずんだもんスタイルの質問応答データ)
158+
dataset = load_dataset("takaaki-inada/databricks-dolly-15k-ja-zundamon")
159+
160+
# 使用するデータセットの件数
161+
train_num = 30
162+
163+
# データセットからDSPy形式のExampleオブジェクトを作成
164+
# - query: 質問文
165+
# - history: 会話履歴(今回は空リスト)
166+
# - response: 期待される応答(学習用)
167+
trainset = [
168+
dspy.Example(query=item["instruction"], history=[], response=item["output"]).with_inputs(
169+
"query", "history"
170+
) # 入力フィールドを指定
171+
for item in list(dataset["train"])[:train_num] # 最初の30件を使用
172+
]
173+
174+
# MIPROv2を使用してチャットボットを最適化
175+
optimized_bot = optimize_with_miprov2(trainset, eval_lm, chat_lm)
176+
177+
# 最適化されたモデルをファイルに保存
178+
optimized_bot.save(OPTIMIAZED_MODEL_PATH)
179+
print(f"✅ モデルを保存しました: {OPTIMIAZED_MODEL_PATH}")
180+
181+
# 保存したモデルを読み込んでテスト
182+
test_bot = EdamameFairyBot()
183+
test_bot.load(OPTIMIAZED_MODEL_PATH)
184+
185+
# テスト用のLM設定(推論用)
186+
dspy.configure(lm=chat_lm)
187+
188+
# テスト用のクエリ(様々なタイプの質問)
189+
test_queries = ["こんにちは!", "枝豆って美味しいよね", "DSPyについて教えて"]
190+
191+
# テスト実行と結果表示
192+
print("\n🧪 テスト結果:")
193+
for query in test_queries:
194+
# 最適化されたボットで応答を生成
195+
result = test_bot(query=query, history=[])
196+
print(f"Q: {query}")
197+
print(f"A: {result.response}\n")
198+
199+
200+
if __name__ == "__main__":
201+
main()

main.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
枝豆の妖精スタイルの対話型チャット
3+
"""
4+
5+
from collections import deque
6+
7+
import dspy
8+
from dotenv import load_dotenv
9+
10+
from chatbot_module import EdamameFairyBot
11+
from chatbot_tuning import OPTIMIAZED_MODEL_PATH
12+
from template_langgraph.llms.azure_openais import Settings
13+
14+
settings = Settings()
15+
16+
load_dotenv()
17+
18+
19+
def main():
20+
# チャット用LM設定
21+
lm = dspy.LM(
22+
model=f"azure/{settings.azure_openai_model_chat}",
23+
api_key=settings.azure_openai_api_key,
24+
api_base=settings.azure_openai_endpoint,
25+
api_version=settings.azure_openai_api_version,
26+
temperature=0.0,
27+
max_tokens=1000,
28+
)
29+
30+
# DSPy標準のグローバル設定
31+
dspy.configure(lm=lm)
32+
33+
# モデル読み込み
34+
print("📂 モデル読み込み中...")
35+
chatbot = EdamameFairyBot()
36+
chatbot.load(OPTIMIAZED_MODEL_PATH)
37+
print("✅ 最適化済みモデルを読み込みました")
38+
39+
# 対話履歴を管理
40+
history = deque(maxlen=5)
41+
42+
print("\n🌱 枝豆の妖精チャットボット")
43+
print("('quit' または 'exit' で終了)")
44+
print("-" * 50)
45+
46+
while True:
47+
user_input = input("\nあなた: ")
48+
49+
if user_input.lower() in ["quit", "exit", "終了"]:
50+
print("\n🌱妖精: バイバイなのだ!")
51+
break
52+
53+
# 履歴をリスト形式に変換
54+
history_list = [f"User: {h[0]}\nBot: {h[1]}" for h in history]
55+
56+
# 応答生成
57+
result = chatbot(query=user_input, history=history_list)
58+
print(f"🌱妖精: {result.response}")
59+
60+
# 履歴に追加
61+
history.append((user_input, result.response))
62+
63+
64+
if __name__ == "__main__":
65+
main()

0 commit comments

Comments
 (0)