Chain-of-Thought (CoT) prompting is a technique that enables large language models to solve complex reasoning tasks by generating intermediate reasoning steps before producing a final answer. Instead of directly answering a question, the model first articulates its thought process, leading to more accurate and interpretable results.
Key Insight: Complex reasoning emerges when models are prompted to "show their work" - the intermediate steps serve as a scaffold for solving multi-step problems.
Traditional prompting often fails on tasks requiring:
- Multi-step arithmetic
- Logical deduction
- Commonsense reasoning chains
- Symbol manipulation
CoT addresses this by making the reasoning process explicit, allowing the model to:
- Break down complex problems
- Track intermediate results
- Compose multi-step solutions
- Self-correct through explicit reasoning
CoT reasoning emerges in models with sufficient scale (typically >100B parameters). The phenomenon is explained by:
- Compositional Reasoning: Models learn to compose simple reasoning primitives
- In-Context Learning: Few-shot examples demonstrate the reasoning pattern
- Latent Reasoning: Models develop internal representations for multi-step processes
Given a problem $x$, standard prompting models the answer distribution $p(y \mid x)$ directly. Chain-of-Thought introduces intermediate reasoning steps $z_1, \ldots, z_n$ and factorizes the answer through them:

$$p(y \mid x) = \sum_{z_{1:n}} p(y \mid x, z_{1:n})\, p(z_{1:n} \mid x)$$

In practice, we use greedy decoding: each step is the most likely continuation given the problem and the chain so far,

$$z_i = \arg\max_{z} p(z \mid x, z_{1:i-1}), \qquad y = \arg\max_{y'} p(y' \mid x, z_{1:n})$$

The neural implementation uses layered attention to maintain reasoning state. For each reasoning step $i$:

$$h_i = \mathrm{Attention}(Q_i, K_{1:i}, V_{1:i})$$

Where:
- $Q_i$ = query for current step
- $K_{1:i}, V_{1:i}$ = keys/values from all previous steps
- The model attends to prior reasoning when generating new thoughts
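The step-wise attention above can be demonstrated numerically. A minimal NumPy sketch (the function name and shapes are illustrative, not from the Nexus codebase):

```python
import numpy as np

def step_attention(Q, K, V):
    """h_i = softmax(Q_i K_{1:i}^T / sqrt(d)) V_{1:i}: each step attends
    only to itself and earlier steps (causal masking over steps)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    # Mask future steps so step i never sees steps i+1..n
    scores[np.triu(np.ones(scores.shape, dtype=bool), k=1)] = -np.inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```

With the causal mask in place, the first step's output equals its own value vector, since it has nothing earlier to attend to.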
┌─────────────┐
│ Problem │
└──────┬──────┘
│
▼
┌──────────────────────┐
│ Step Embedding 1 │
│ ┌────────────────┐ │
│ │ Thought Layer │ │
│ └────────┬───────┘ │
│ │ │
│ ┌────────▼───────┐ │
│ │ Reasoning 1 │ │
│ └────────────────┘ │
└──────┬───────────────┘
│
▼
┌──────────────────────┐
│ Step Embedding 2 │
│ ┌────────────────┐ │
│ │ Thought Layer │ │ ◄── Attends to previous steps
│ └────────┬───────┘ │
│ │ │
│ ┌────────▼───────┐ │
│ │ Reasoning 2 │ │
│ └────────────────┘ │
└──────┬───────────────┘
│
⋮ (N steps)
│
▼
┌──────────────────────┐
│ Final Answer │
└──────────────────────┘
Problem: "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"
Without CoT: "14" (wrong)
With CoT:
- Step 1: Roger started with 5 balls
- Step 2: He bought 2 cans with 3 balls each
- Step 3: 2 cans × 3 balls/can = 6 balls
- Step 4: Total = 5 + 6 = 11 balls
- Answer: 11 (correct)
Zero-Shot CoT: Simply append "Let's think step by step" to the problem:
Q: What is 15% of 80?
A: Let's think step by step.
Few-Shot CoT: Provide examples with reasoning chains:
Q: John has 3 apples. He buys 2 more. How many does he have?
A: John started with 3 apples. He bought 2 more.
So 3 + 2 = 5 apples.
Q: Sarah has 10 cookies. She eats 3. How many are left?
A: [Model generates reasoning]
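Both prompt styles can be assembled programmatically. A minimal sketch; `build_cot_prompt` and the example list are illustrative helpers, not part of any library:

```python
# Hypothetical few-shot demonstrations (question, reasoning chain) pairs
FEW_SHOT_EXAMPLES = [
    ("John has 3 apples. He buys 2 more. How many does he have?",
     "John started with 3 apples. He bought 2 more. So 3 + 2 = 5 apples."),
]

def build_cot_prompt(question, examples=None,
                     trigger="Let's think step by step."):
    """Build a zero-shot or few-shot CoT prompt string."""
    if examples:
        # Few-shot: prepend worked reasoning chains as demonstrations
        demos = "\n\n".join(f"Q: {q}\nA: {a}" for q, a in examples)
        return f"{demos}\n\nQ: {question}\nA:"
    # Zero-shot: append the trigger phrase (Kojima et al., 2022)
    return f"Q: {question}\nA: {trigger}"
```

The few-shot variant ends with a bare "A:" so the model continues with its own reasoning chain.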
The Nexus implementation uses specialized reasoning modules:
class ReasoningStep:
- attention: Multi-head self-attention over current reasoning state
- norm: Layer normalization
- ffn: Feed-forward transformation
forward(x, context):
    attended = self.attention(x)
    if context is not None:
        attended += self.attention(x, context)  # Attend to original problem
    x = self.norm(x + attended)
    x = x + self.ffn(x)
    return x

Key components:
- Step Embeddings: Distinguish reasoning stages
- Residual Connections: Preserve information across steps
- Context Attention: Maintain focus on original problem
Reference implementation: Nexus/nexus/models/nlp/chain_of_thoughts.py
from nexus.models.nlp.chain_of_thoughts import ChainOfThoughtModule
config = {
    "num_reasoning_steps": 4,
    "hidden_size": 768,
    "vocab_size": 50257
}
cot_module = ChainOfThoughtModule(config)
# Forward pass
outputs = cot_module(
    hidden_states=input_embeddings,  # (batch_size, seq_len, hidden_size)
    attention_mask=mask
)
# Access reasoning steps
logits = outputs["logits"] # Final predictions
reasoning_steps = outputs["reasoning_steps"] # List of intermediate states
attention_maps = outputs["attention_maps"]  # Attention patterns per step

from nexus.models.nlp.chain_of_thoughts import ReasoningLLM
config = {
    "vocab_size": 50257,
    "hidden_size": 768,
    "max_seq_length": 512,
    "num_reasoning_steps": 4
}
model = ReasoningLLM(config)
# Generate with reasoning
outputs = model(
    input_ids=input_tokens,
    attention_mask=mask,
    return_reasoning_steps=True
)
# Inspect reasoning process
for i, step in enumerate(outputs["reasoning_steps"]):
    print(f"Step {i}: {step}")

import torch.nn as nn

class AdaptiveCoT(nn.Module):
"""Dynamically determine reasoning steps based on problem complexity"""
def __init__(self, config):
super().__init__()
self.max_steps = config["max_steps"]
self.complexity_scorer = nn.Linear(hidden_size, 1)
def forward(self, x):
# Predict number of reasoning steps needed
complexity = self.complexity_scorer(x.mean(dim=1))
num_steps = min(max(1, int(complexity.item())), self.max_steps)
# Run adaptive reasoning
for step in range(num_steps):
x = self.reasoning_step(x)
return xInitialize step embeddings with positional encoding patterns:
import math
import torch

# Sinusoidal step embeddings (assumes hidden_size is even)
step_embeddings = torch.zeros(num_steps, 1, hidden_size)
for pos in range(num_steps):
    for i in range(0, hidden_size, 2):
        step_embeddings[pos, 0, i] = math.sin(pos / (10000 ** (i / hidden_size)))
        step_embeddings[pos, 0, i + 1] = math.cos(pos / (10000 ** (i / hidden_size)))

Use pre-normalization instead of post-normalization:
# Pre-norm (better gradient flow)
x = x + self.attention(self.norm1(x))
x = x + self.ffn(self.norm2(x))
# vs Post-norm (standard)
x = self.norm1(x + self.attention(x))
x = self.norm2(x + self.ffn(x))

Stop reasoning when confidence is high:
for step in range(max_steps):
    x = reasoning_step(x)
    confidence = torch.softmax(output_head(x), dim=-1).max()
    if confidence > 0.95:
        break

Cache intermediate reasoning for similar problems:
reasoning_cache = {}

def cached_reasoning(problem_embedding):
    # Hash by value (not object identity) so equal embeddings hit the cache
    key = problem_embedding.detach().cpu().numpy().tobytes()
    if key in reasoning_cache:
        return reasoning_cache[key]
    result = run_reasoning(problem_embedding)
    reasoning_cache[key] = result
    return result

Results from Wei et al. (2022):
| Task | Standard Prompting | CoT Prompting | Improvement |
|---|---|---|---|
| GSM8K (Math) | 17.9% | 58.1% | +40.2% |
| SVAMP (Math) | 69.9% | 78.7% | +8.8% |
| AQuA (Math) | 33.7% | 50.3% | +16.6% |
| CommonsenseQA | 67.4% | 79.2% | +11.8% |
| StrategyQA | 54.2% | 66.1% | +11.9% |
CoT benefits increase with model scale:
| Model Size | CoT Gain on GSM8K |
|---|---|
| 1B params | +2.1% |
| 10B params | +8.5% |
| 60B params | +25.3% |
| 175B params | +40.2% |
| 540B params | +52.7% |
Number of Reasoning Steps:
- 1 step: 45.2% accuracy
- 2 steps: 52.3% accuracy
- 4 steps: 58.1% accuracy
- 8 steps: 57.9% accuracy (diminishing returns)
Step Embedding Impact:
- Without step embeddings: 51.4%
- With step embeddings: 58.1%
- Gain: +6.7%
Problem: CoT reasoning emerges at scale. Small models (<10B) show minimal gains.
Solution: Use models ≥60B parameters, or fine-tune smaller models on reasoning datasets.
Problem: Low-quality examples lead to degraded reasoning.
# Bad example (too vague)
"Q: Math problem. A: Use numbers."
# Good example (clear reasoning)
"Q: 3 + 5 × 2 = ?
A: Following order of operations, first multiply: 5 × 2 = 10.
Then add: 3 + 10 = 13."

Problem: Too many reasoning steps waste computation and may introduce errors.
Solution: Start with 3-5 steps. Monitor validation performance vs. computational cost.
Problem: Reasoning drifts away from the original problem.
Solution: Use cross-attention to maintain focus on the problem:
# Maintain problem context
attended = self.attention(current_step, context=original_problem)

Problem: Gradients vanish/explode through long reasoning chains.
Solutions:
- Use gradient clipping: torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
- Pre-normalization architecture
- Warmup learning rate schedule
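Taken together, these fixes amount to a standard training-step pattern. A minimal sketch with an illustrative toy model; the names and hyperparameters are assumptions, not the Nexus API:

```python
import torch
import torch.nn as nn

def train_step(model, optimizer, scheduler, x, y, max_norm=1.0):
    """One update: loss, backward, clip gradients, step, advance warmup."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Clip gradients flowing through the (potentially long) reasoning chain
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    scheduler.step()
    return loss.item()

model = nn.Linear(8, 1)  # stand-in for a reasoning model (ideally pre-norm)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Linear warmup over the first 100 steps, then constant
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / 100))
```

The scheduler scales the base learning rate up gradually, which avoids large early updates that destabilize long reasoning chains.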
Problem: Different phrasings yield very different results.
Solution: Test multiple prompt variants:
prompts = [
    "Let's solve this step by step:",
    "Let's think through this carefully:",
    "Breaking this down:",
    "Step-by-step solution:"
]

# Run ensemble over prompts
results = [model(problem + prompt) for prompt in prompts]
final_answer = majority_vote(results)

Decompose complex problems into simpler subproblems:
Problem: "What is the result of (3 + 5) × (2 + 4)?"
Subproblem 1: What is 3 + 5?
Answer: 8
Subproblem 2: What is 2 + 4?
Answer: 6
Main problem: What is 8 × 6?
Answer: 48
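The decompose-then-solve flow above can be sketched as a two-stage prompting loop; `llm` is a hypothetical callable that maps a prompt string to a model completion:

```python
def least_to_most(llm, problem):
    """Least-to-most prompting: decompose, solve subproblems in order, combine."""
    # Stage 1: ask the model to list subproblems, one per line
    decomposition = llm(f"Decompose into subproblems, one per line:\n{problem}")
    subproblems = [line for line in decomposition.splitlines() if line.strip()]

    # Stage 2: solve each subproblem, feeding earlier answers forward as context
    context = f"Problem: {problem}\n"
    for sub in subproblems:
        answer = llm(f"{context}Subproblem: {sub}\nAnswer:")
        context += f"Subproblem: {sub}\nAnswer: {answer}\n"

    # Final stage: answer the original problem given all solved subproblems
    return llm(f"{context}Main problem: {problem}\nAnswer:")
```

Because each subproblem's answer is appended to the context, later subproblems can build on earlier results, which is the core of the least-to-most strategy.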
Use reasoning to verify candidate answers:
# Generate answer
answer = model.generate(problem)
# Verify with backward reasoning
verification_prompt = f"Verify that {answer} is correct for: {problem}"
verification = model.generate(verification_prompt)
if "incorrect" in verification.lower():
answer = model.generate(problem, temperature=0.8) # Regenerate-
Chain-of-Thought Prompting Elicits Reasoning in Large Language Models Wei et al., NeurIPS 2022 https://arxiv.org/abs/2201.11903
-
Large Language Models are Zero-Shot Reasoners Kojima et al., NeurIPS 2022 https://arxiv.org/abs/2205.11916
-
Least-to-Most Prompting Enables Complex Reasoning in Large Language Models Zhou et al., ICLR 2023 https://arxiv.org/abs/2205.10625
-
Automatic Chain of Thought Prompting in Large Language Models Zhang et al., ICLR 2023 https://arxiv.org/abs/2210.03493
- Self-Consistency: Sample multiple CoT paths and aggregate via voting
- Tree of Thoughts: Explore multiple reasoning paths in a tree structure
- Least-to-Most: Decompose-then-solve strategy
- Complexity-Based Prompting: Select few-shot examples by complexity
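The Self-Consistency variant listed above is simple to sketch; `llm` (a sampling-based generator) and `extract_answer` (a final-answer parser) are hypothetical stand-ins:

```python
from collections import Counter

def self_consistency(llm, extract_answer, problem, num_samples=5):
    """Sample several CoT paths and return the majority-vote final answer."""
    answers = [extract_answer(llm(problem)) for _ in range(num_samples)]
    # Aggregate: the most frequent final answer wins
    return Counter(answers).most_common(1)[0][0]
```

Sampling must use temperature > 0 so the reasoning paths actually differ; with greedy decoding every sample would return the same chain and voting would be pointless.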