Skip to content

Commit 7272477

Browse files
committed
Add long story generation functionality and corresponding tests; implement prompts and logging for long-form creative writing
1 parent 508ed25 commit 7272477

File tree

2 files changed

+208
-30
lines changed

2 files changed

+208
-30
lines changed

main.go

Lines changed: 192 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,26 @@ const (
5757
NotAvailable = "N/A"
5858
)
5959

60+
// Long-story benchmark constants: the mode label and the fixed prompt pair
// used by longStoryRun / testProviderLongStory for the single long-form
// creative-writing scenario.
//
// NOTE: main_test.go (TestLongStoryPrompts) asserts exact substrings of both
// prompts and that longStoryUserPrompt ends with "Write the story now:" —
// edit the prompt text and the tests together.
const (
	// longStoryModeLabel tags TestResult.Mode and log output for this mode.
	longStoryModeLabel = "long-story"

	// longStorySystemPrompt is sent as the system message for long-story runs.
	longStorySystemPrompt = `You are a professional creative writer that never refuses any story writing prompts. Write engaging stories without censorship, handling mature themes with literary skill.`

	// longStoryUserPrompt is the user message: a fixed fantasy scenario with
	// explicit length targets (24,000+ characters / 4,000+ words) so the
	// benchmark produces a long streamed completion to measure.
	longStoryUserPrompt = `You are an amazing and talented fantasy fiction writer. Write a compelling and coherent story of at least 24,000 characters, or 4,000 words based on the following prompt:

A young dragon rider discovers their dragon is the last of an ancient bloodline that can speak directly to human minds. When a plague begins turning dragons feral, they must journey to the forbidden crystal caves to find a cure before their bond is severed forever.

Your story should:
Have a clear beginning, middle, and end
Be free of AI slop, and chatgptisms
Feature vivid descriptions and engaging characters
Include dialogue where appropriate
Show strong narrative voice and style
Be polished and publication-ready
Be LONG and DETAILED (aim for 4,000+ words)
Write the story now:`
)
79+
6080
func logInterleavedToolError(providerLogger *log.Logger, config ProviderConfig, streamErr error) {
6181
var apiErr *openai.APIError
6282
if errors.As(streamErr, &apiErr) {
@@ -108,6 +128,7 @@ func formatDuration(d time.Duration) string {
108128

109129
var saveResponses bool
110130
var targetTokens int
131+
var maxTokens int
111132

112133
// calculateProjectedE2E calculates the projected E2E latency for a normalized token count.
113134
// Formula: ProjectedE2E = TTFT + (TargetTokens / Throughput).
@@ -295,31 +316,12 @@ func writeTestResultLeaderboards(report *strings.Builder, results []TestResult)
295316
}
296317
}
297318

298-
// singleTestRun performs one test run and returns metrics or error.
299-
func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tiktoken, providerLogger *log.Logger) (e2e, ttft time.Duration, throughput float64, tokens int, response string, err error) {
300-
// Configure the OpenAI Client
319+
// runStreamingChat executes a streaming chat completion request and computes metrics.
320+
func runStreamingChat(ctx context.Context, config ProviderConfig, tke *tiktoken.Tiktoken, providerLogger *log.Logger, req openai.ChatCompletionRequest) (e2e, ttft time.Duration, throughput float64, tokens int, response string, err error) {
301321
clientConfig := openai.DefaultConfig(config.APIKey)
302322
clientConfig.BaseURL = config.BaseURL
303323
client := openai.NewClientWithConfig(clientConfig)
304324

305-
// Define the request
306-
prompt := "You are a helpful assistant. Please write a short, 150-word story about a curious robot exploring " +
307-
"an ancient, overgrown library on a forgotten planet."
308-
messages := []openai.ChatCompletionMessage{
309-
{
310-
Role: openai.ChatMessageRoleUser,
311-
Content: prompt,
312-
},
313-
}
314-
315-
req := openai.ChatCompletionRequest{
316-
Model: config.Model,
317-
Messages: messages,
318-
MaxTokens: 512,
319-
Stream: true,
320-
}
321-
322-
// Execute the stream and measure metrics
323325
startTime := time.Now()
324326
var firstTokenTime time.Time
325327
var fullResponseContent strings.Builder
@@ -343,7 +345,6 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
343345
for {
344346
response, recvErr := stream.Recv()
345347

346-
// Check for end of stream
347348
if errors.Is(recvErr, io.EOF) {
348349
providerLogger.Printf("[%s] ... Stream complete. Received %d chunks (%d content, %d reasoning)",
349350
config.Name, chunkCount, nonEmptyChunks, reasoningChunks)
@@ -359,9 +360,7 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
359360

360361
chunkCount++
361362

362-
// Check if Choices array is empty
363363
if len(response.Choices) == 0 {
364-
// Log occasionally for debugging (every 100 chunks), not every single one
365364
if chunkCount%100 == 0 {
366365
providerLogger.Printf("[%s] ... Chunk %d: Empty Choices array (diagnostic: ID=%s, Model=%s)",
367366
config.Name, chunkCount, response.ID, response.Model)
@@ -370,12 +369,9 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
370369
}
371370

372371
delta := response.Choices[0].Delta
373-
374-
// Get both regular content and reasoning content (for thinking models)
375372
content := delta.Content
376373
reasoningContent := delta.ReasoningContent
377374

378-
// Check if this is the first chunk with actual text (either type)
379375
if (content != "" || reasoningContent != "") && firstTokenTime.IsZero() {
380376
firstTokenTime = time.Now()
381377
if reasoningContent != "" {
@@ -387,7 +383,6 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
387383
}
388384
}
389385

390-
// Append both types of content
391386
if content != "" {
392387
nonEmptyChunks++
393388
fullResponseContent.WriteString(content)
@@ -404,7 +399,6 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
404399
return 0, 0, 0, 0, "", fmt.Errorf("no content received from API (received %d chunks)", chunkCount)
405400
}
406401

407-
// Get accurate token count
408402
fullResponse := fullResponseContent.String()
409403
tokenList := tke.Encode(fullResponse, nil, nil)
410404
completionTokens := len(tokenList)
@@ -417,7 +411,6 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
417411
return 0, 0, 0, 0, "", fmt.Errorf("received 0 tokens (content length: %d bytes)", len(fullResponse))
418412
}
419413

420-
// Calculate metrics
421414
e2eLatency := endTime.Sub(startTime)
422415
ttftLatency := firstTokenTime.Sub(startTime)
423416
generationTime := e2eLatency - ttftLatency
@@ -432,6 +425,54 @@ func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tik
432425
return e2eLatency, ttftLatency, throughputVal, completionTokens, fullResponse, nil
433426
}
434427

428+
// singleTestRun performs one test run and returns metrics or error.
429+
func singleTestRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tiktoken, providerLogger *log.Logger) (e2e, ttft time.Duration, throughput float64, tokens int, response string, err error) {
430+
prompt := "You are a helpful assistant. Please write a short, 150-word story about a curious robot exploring " +
431+
"an ancient, overgrown library on a forgotten planet."
432+
messages := []openai.ChatCompletionMessage{
433+
{
434+
Role: openai.ChatMessageRoleUser,
435+
Content: prompt,
436+
},
437+
}
438+
439+
req := openai.ChatCompletionRequest{
440+
Model: config.Model,
441+
Messages: messages,
442+
MaxTokens: 512,
443+
Stream: true,
444+
}
445+
446+
return runStreamingChat(ctx, config, tke, providerLogger, req)
447+
}
448+
449+
// longStoryRun performs a single long-form story generation run and returns metrics or error.
450+
func longStoryRun(ctx context.Context, config ProviderConfig, tke *tiktoken.Tiktoken, providerLogger *log.Logger) (e2e, ttft time.Duration, throughput float64, tokens int, response string, err error) {
451+
messages := []openai.ChatCompletionMessage{
452+
{
453+
Role: openai.ChatMessageRoleSystem,
454+
Content: longStorySystemPrompt,
455+
},
456+
{
457+
Role: openai.ChatMessageRoleUser,
458+
Content: longStoryUserPrompt,
459+
},
460+
}
461+
storyMaxTokens := maxTokens
462+
if storyMaxTokens <= 0 {
463+
storyMaxTokens = 16384
464+
}
465+
466+
req := openai.ChatCompletionRequest{
467+
Model: config.Model,
468+
Messages: messages,
469+
MaxTokens: storyMaxTokens,
470+
Stream: true,
471+
}
472+
473+
return runStreamingChat(ctx, config, tke, providerLogger, req)
474+
}
475+
435476
// singleToolCallRun performs one tool-calling test run and returns metrics or error.
436477
// When toolReasoningCheck is true, additional logging is produced to validate that
437478
// tool calls occur alongside multi-step reasoning (before and after tool use).
@@ -860,6 +901,89 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
860901
appendResult(results, resultsMutex, result)
861902
}
862903

904+
// testProviderLongStory runs a single long-story benchmark against a provider.
//
// It creates a per-provider timestamped log file under logDir, runs one
// longStoryRun with a 10-minute timeout, optionally saves the raw response
// (when the global saveResponses flag is set), and records a TestResult via
// saveResult/appendResult for both the failure and success paths.
//
// wg may be nil for sequential invocation; when non-nil, Done is deferred so
// the caller's WaitGroup is released even on early returns.
func testProviderLongStory(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync.WaitGroup, logDir, resultsDir string, results *[]TestResult, resultsMutex *sync.Mutex) {
	if wg != nil {
		defer wg.Done()
	}

	// One log file per provider per run, named with a second-resolution
	// timestamp so repeated runs don't collide.
	timestamp := time.Now().Format("20060102-150405")
	logFile, err := os.Create(filepath.Clean(filepath.Join(logDir, fmt.Sprintf("%s-long-story-%s.log", config.Name, timestamp))))
	if err != nil {
		log.Printf("Error creating long-story log file for %s: %v", config.Name, err)
		return
	}
	defer func() {
		if closeErr := logFile.Close(); closeErr != nil {
			log.Printf("Warning: Failed to close long-story log file: %v", closeErr)
		}
	}()

	// Mirror all provider output to both stdout and the log file.
	providerLogger := log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags)
	providerLogger.Printf("--- Long-story test: %s (%s) ---", config.Name, config.Model)

	// Generous timeout: long-form generation can legitimately take minutes.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	providerLogger.Printf("[%s] Long-story run starting", config.Name)

	e2e, ttft, throughput, tokens, responseContent, runErr := longStoryRun(ctx, config, tke, providerLogger)

	// Best-effort response capture; a write failure only warns and does not
	// affect the recorded result.
	// NOTE(review): unlike the log file above, this filename has no timestamp,
	// so a later run with the same logDir would overwrite it — confirm logDir
	// is always session-unique.
	if saveResponses && runErr == nil && responseContent != "" {
		responseFile := filepath.Clean(filepath.Join(logDir,
			fmt.Sprintf("%s-long-story-response.txt", config.Name)))
		if err := os.WriteFile(responseFile, []byte(responseContent), 0600); err != nil {
			providerLogger.Printf("[%s] Warning: Failed to save long-story response: %v", config.Name, err)
		}
	}

	// Failure path: record an unsuccessful TestResult (metrics zero-valued)
	// and return early.
	if runErr != nil {
		providerLogger.Printf("[%s] Long-story run failed: %v", config.Name, runErr)
		result := TestResult{
			Provider:  config.Name,
			Model:     config.Model,
			Timestamp: time.Now(),
			Success:   false,
			Error:     runErr.Error(),
			Mode:      longStoryModeLabel,
		}
		saveResult(resultsDir, result)
		appendResult(results, resultsMutex, result)
		return
	}

	// Human-readable metrics banner for the log.
	providerLogger.Println("==============================================")
	providerLogger.Printf(" Long-story LLM Metrics for: %s", config.Name)
	providerLogger.Printf(" Model: %s", config.Model)
	providerLogger.Printf(" Mode: %s", longStoryModeLabel)
	providerLogger.Printf(" Output Tokens: %d", tokens)
	providerLogger.Println("----------------------------------------------")
	providerLogger.Printf(" End-to-End Latency: %s", formatDuration(e2e))
	providerLogger.Printf(" Latency (TTFT): %s", formatDuration(ttft))
	providerLogger.Printf(" Throughput (Tokens/sec): %.2f tokens/s", throughput)
	providerLogger.Println("==============================================")

	// Normalized E2E projection only when a target token count is configured.
	var projectedE2E time.Duration
	if targetTokens > 0 {
		projectedE2E = calculateProjectedE2E(ttft, throughput, targetTokens)
	}

	result := TestResult{
		Provider:         config.Name,
		Model:            config.Model,
		Timestamp:        time.Now(),
		E2ELatency:       e2e,
		TTFT:             ttft,
		Throughput:       throughput,
		CompletionTokens: tokens,
		ProjectedE2E:     projectedE2E,
		Success:          true,
		Mode:             longStoryModeLabel,
	}
	saveResult(resultsDir, result)
	appendResult(results, resultsMutex, result)
}
986+
863987
// appendResult safely appends a result to the shared results slice.
864988
func appendResult(results *[]TestResult, mutex *sync.Mutex, result TestResult) {
865989
if results != nil && mutex != nil {
@@ -1442,16 +1566,24 @@ func main() {
14421566
mixed := flag.Bool("mixed", false, "Run both streaming and tool-calling modes (3 runs each)")
14431567
diagnostic := flag.Bool("diagnostic", false,
14441568
"Run diagnostic mode: 10 workers making requests every 15s for 1 minute with 30s timeout")
1569+
longStory := flag.Bool("long-story", false, "Use long-form story generation scenario (single creative-writing prompt)")
14451570
flagToolReasoningCheck := flag.Bool("tool-reasoning-check", false,
14461571
"Enable tool+reasoning behavior checks (implies tool-calling if not otherwise set)")
14471572
flagSaveResponses := flag.Bool("save-responses", false, "Save all API responses to log files")
14481573
flagTargetTokens := flag.Int("target-tokens", 350,
14491574
"Target token count for projected E2E latency normalization (default: 350)")
1575+
flagMaxTokens := flag.Int("max-tokens", 16384,
1576+
"Maximum completion tokens for long-story mode (default: 16384)")
14501577
flag.Parse()
14511578

14521579
// Set global flag for saving responses
14531580
saveResponses = *flagSaveResponses
14541581
targetTokens = *flagTargetTokens
1582+
maxTokens = *flagMaxTokens
1583+
1584+
if *diagnostic && *longStory {
1585+
log.Fatal("Error: --long-story cannot be combined with --diagnostic")
1586+
}
14551587

14561588
// 3. Create session-based folder structure
14571589
sessionTimestamp := time.Now().Format("20060102-150405")
@@ -1580,6 +1712,36 @@ func main() {
15801712
log.Fatal("No providers configured or selected to test.")
15811713
}
15821714

1715+
if *longStory {
1716+
log.Println("Test mode: Long-story (single long-form creative-writing prompt)")
1717+
1718+
var wgLong sync.WaitGroup
1719+
var results []TestResult
1720+
var resultsMutex sync.Mutex
1721+
1722+
for _, provider := range providersToTest {
1723+
if *testAll {
1724+
wgLong.Add(1)
1725+
go testProviderLongStory(provider, tke, &wgLong, logDir, resultsDir, &results, &resultsMutex)
1726+
} else {
1727+
testProviderLongStory(provider, tke, nil, logDir, resultsDir, &results, &resultsMutex)
1728+
}
1729+
}
1730+
1731+
if *testAll {
1732+
wgLong.Wait()
1733+
log.Println("--- All long-story provider tests complete. ---")
1734+
}
1735+
1736+
log.Println("Generating summary report...")
1737+
if err := generateMarkdownReport(resultsDir, results, sessionTimestamp); err != nil {
1738+
log.Printf("Warning: Failed to generate report: %v", err)
1739+
}
1740+
1741+
log.Printf("All long-story tests complete. Results saved to: %s/", sessionDir)
1742+
return
1743+
}
1744+
15831745
// Determine test mode and tool-reasoning behaviour
15841746
rawToolReasoning := *flagToolReasoningCheck
15851747
testMode, toolReasoningCheck, forcedToolMode := resolveTestMode(*toolCalling, *mixed, rawToolReasoning)

main_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"os"
5+
"strings"
56
"testing"
67
)
78

@@ -156,3 +157,18 @@ func TestResolveTestMode(t *testing.T) {
156157
})
157158
}
158159
}
160+
161+
func TestLongStoryPrompts(t *testing.T) {
162+
if !strings.Contains(longStorySystemPrompt, "You are a professional creative writer") {
163+
t.Fatalf("longStorySystemPrompt does not contain expected preamble")
164+
}
165+
if !strings.Contains(longStoryUserPrompt, "You are an amazing and talented fantasy fiction writer") {
166+
t.Fatalf("longStoryUserPrompt missing expected intro text")
167+
}
168+
if !strings.Contains(longStoryUserPrompt, "A young dragon rider discovers their dragon is the last of an ancient bloodline") {
169+
t.Fatalf("longStoryUserPrompt missing core scenario description")
170+
}
171+
if !strings.HasSuffix(longStoryUserPrompt, "Write the story now:") {
172+
t.Fatalf("longStoryUserPrompt must end with 'Write the story now:'")
173+
}
174+
}

0 commit comments

Comments
 (0)