Skip to content

Commit 41f874c

Browse files
authored
rust data loader example (#1452)
1 parent 44c5be9 commit 41f874c

File tree

9 files changed

+588
-0
lines changed

9 files changed

+588
-0
lines changed

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
- [`loader-r-to-csv`](https://observablehq.observablehq.cloud/framework-example-loader-r-to-csv/) - Generating CSV from R
5454
- [`loader-r-to-jpeg`](https://observablehq.observablehq.cloud/framework-example-loader-r-to-jpeg/) - Generating JPEG from R
5555
- [`loader-r-to-json`](https://observablehq.observablehq.cloud/framework-example-loader-r-to-json/) - Generating JSON from R
56+
- [`loader-rust-to-json`](https://observablehq.observablehq.cloud/framework-example-loader-rust-to-json/) - Generating JSON from Rust
5657
- [`loader-snowflake`](https://observablehq.observablehq.cloud/framework-example-loader-snowflake/) - Loading data from Snowflake
5758
- [`netcdf-contours`](https://observablehq.observablehq.cloud/framework-example-netcdf-contours/) - Converting NetCDF to GeoJSON with `netcdfjs` and `d3-geo-voronoi`
5859

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.DS_Store
2+
/dist/
3+
node_modules/
4+
yarn-error.log
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[Framework examples →](../)
2+
3+
# Rust data loader to generate JSON
4+
5+
View live: <https://observablehq.observablehq.cloud/framework-example-loader-rust-to-json/>
6+
7+
This Observable Framework example demonstrates how to write a data loader in Rust that runs a Monte Carlo simulation of poker hands, calculates statistics about how often each category of hand was found, then outputs JSON.
8+
9+
The data loader lives in [`src/data/poker.json.rs`](./src/data/poker.json.rs).
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export default {
2+
root: "src"
3+
};
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{
2+
"type": "module",
3+
"private": true,
4+
"scripts": {
5+
"clean": "rimraf src/.observablehq/cache",
6+
"build": "rimraf dist && observable build",
7+
"dev": "observable preview",
8+
"deploy": "observable deploy",
9+
"observable": "observable"
10+
},
11+
"dependencies": {
12+
"@observablehq/framework": "^1.8.0"
13+
},
14+
"devDependencies": {
15+
"rimraf": "^5.0.5"
16+
},
17+
"engines": {
18+
"node": ">=18"
19+
}
20+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/.observablehq/cache/
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import {max, min, rollup, sort} from "d3-array";
2+
3+
function main() {
4+
const COUNT = 100_000;
5+
const start = performance.now();
6+
7+
const counts = Array.from({length: COUNT})
8+
// Calculate the category of random hands
9+
.map(() => {
10+
const hand = Hand.random();
11+
// Convert the category into a one-element hashmap, so the reducer
12+
// can sum up all the counts for each category.
13+
return {[hand.categorize()]: 1};
14+
})
15+
// count up each category
16+
.reduce((acc, next) => {
17+
for (const [category, count] of Object.entries(next)) {
18+
acc[category] = (acc[category] ?? 0) + count;
19+
}
20+
return acc;
21+
}, {});
22+
23+
const tidyData = sort(
24+
Object.entries(counts).map(([category, count]) => ({category, count})),
25+
(d) => d.category
26+
);
27+
28+
process.stdout.write(
29+
JSON.stringify({
30+
summary: tidyData,
31+
meta: {count: COUNT, duration_ms: performance.now() - start}
32+
})
33+
);
34+
}
35+
36+
// Here, we create types for the domain model of a poker hand. Working with
37+
// specific types helps makes the rest of the code simpler.
38+
39+
class Hand {
40+
constructor(public cards: Card[]) {}
41+
42+
static random(): Hand {
43+
const cards: Card[] = [];
44+
while (cards.length < 5) {
45+
const rank = Math.floor(Math.random() * 13 + 1);
46+
const suitRand = Math.random();
47+
const suit =
48+
suitRand < 0.25 ? Suit.Clubs : suitRand < 0.5 ? Suit.Diamonds : suitRand < 0.75 ? Suit.Hearts : Suit.Spades;
49+
const card = {rank, suit};
50+
if (cards.some((c) => c.rank === card.rank && c.suit === card.suit)) {
51+
continue;
52+
}
53+
cards.push(card);
54+
}
55+
return new Hand(cards);
56+
}
57+
58+
categorize(): HandCategory {
59+
const rankCounts = rollup(
60+
this.cards,
61+
(ds) => ds.length,
62+
(d) => d.rank
63+
);
64+
const suitCounts = rollup(
65+
this.cards,
66+
(ds) => ds.length,
67+
(d) => d.rank
68+
);
69+
70+
const isFlush = suitCounts.size == 1;
71+
72+
let isStraight;
73+
74+
if (this.cards.some((c) => c.rank == 1)) {
75+
// Handle aces
76+
const minRank = min(
77+
this.cards.filter((c) => c.rank !== 1),
78+
(d) => d.rank
79+
);
80+
const maxRank = max(
81+
this.cards.filter((c) => c.rank !== 1),
82+
(d) => d.rank
83+
);
84+
isStraight = (minRank == 2 && maxRank == 5) || (minRank == 10 && maxRank == 13);
85+
} else {
86+
const minRank = min(this.cards, (d) => d.rank);
87+
const maxRank = max(this.cards, (d) => d.rank);
88+
isStraight = maxRank! - minRank! === this.cards.length - 1;
89+
}
90+
91+
if (isFlush && isStraight) {
92+
return HandCategory.StraightFlush;
93+
} else if (Array.from(rankCounts.values()).some((count) => count === 4)) {
94+
return HandCategory.FourOfAKind;
95+
} else if (
96+
Array.from(rankCounts.values()).some((count) => count === 3) &&
97+
Array.from(rankCounts.values()).some((count) => count === 2)
98+
) {
99+
return HandCategory.FullHouse;
100+
} else if (isFlush) {
101+
return HandCategory.Flush;
102+
} else if (isStraight) {
103+
return HandCategory.Straight;
104+
} else if (Array.from(rankCounts.values()).some((count) => count === 3)) {
105+
return HandCategory.ThreeOfAKind;
106+
} else if (
107+
Array.from(rankCounts.values())
108+
.filter((count) => count === 2)
109+
.length == 2
110+
) {
111+
return HandCategory.TwoPair;
112+
} else if (Array.from(rankCounts.values()).some((count) => count === 2)) {
113+
return HandCategory.OnePair;
114+
} else {
115+
return HandCategory.HighCard;
116+
}
117+
}
118+
}
119+
120+
type Card = {rank: number; suit: Suit};
121+
122+
enum Suit {
123+
Clubs,
124+
Diamonds,
125+
Hearts,
126+
Spades
127+
}
128+
129+
enum HandCategory {
130+
HighCard = "HighCard",
131+
OnePair = "OnePair",
132+
TwoPair = "TwoPair",
133+
ThreeOfAKind = "ThreeOfAKind",
134+
Straight = "Straight",
135+
Flush = "Flush",
136+
FullHouse = "FullHouse",
137+
FourOfAKind = "FourOfAKind",
138+
StraightFlush = "StraightFlush"
139+
}
140+
141+
main();
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
//! Since Framework uses rust-script, we can define dependencies here.
2+
//!
3+
//! ```cargo
4+
//! [dependencies]
5+
//! serde = { version = "1.0.203", features = ["derive"] }
6+
//! serde_json = "1.0.117"
7+
//! rand = "0.8.5"
8+
//! rayon = "1.10.0"
9+
//! ```
10+
11+
use rand::Rng;
12+
use rayon::prelude::*;
13+
use serde::Serialize;
14+
use serde_json::json;
15+
use std::collections::HashMap;
16+
17+
fn main() {
18+
const COUNT: u32 = 10_000_000;
19+
let start = std::time::Instant::now();
20+
21+
let counts = (0..COUNT)
22+
// This line breaks the work up into multiple parallel jobs.
23+
.into_par_iter()
24+
// Calculate the category of random hands
25+
.map(|_| {
26+
let hand = Hand::random();
27+
// Convert the category into a one-element hashmap, so the reducer
28+
// can sum up all the counts for each category.
29+
let mut map = HashMap::new();
30+
map.insert(hand.categorize(), 1);
31+
map
32+
})
33+
// count up each category
34+
.reduce(
35+
|| HashMap::with_capacity(10),
36+
|mut acc, map| {
37+
for (category, count) in map {
38+
*acc.entry(category).or_insert(0) += count;
39+
}
40+
acc
41+
},
42+
);
43+
44+
let mut tidy_data = counts
45+
.into_iter()
46+
.map(|(category, count)| SummaryRow { category, count })
47+
.collect::<Vec<_>>();
48+
tidy_data.sort_by_key(|data| data.category);
49+
50+
serde_json::to_writer(std::io::stdout(), &json!({
51+
"summary": tidy_data,
52+
"meta": { "count": COUNT, "duration_ms": start.elapsed().as_millis() },
53+
})).unwrap();
54+
}
55+
56+
// Here, we create types for the domain model of a poker hand. Working with
57+
// specific types helps makes the rest of the code simpler.
58+
59+
#[derive(Debug, Clone, Serialize)]
60+
struct SummaryRow {
61+
category: HandCategory,
62+
count: u32,
63+
}
64+
65+
#[derive(Debug, PartialEq, Clone, Serialize)]
66+
struct Hand(Vec<Card>);
67+
68+
#[derive(Debug, PartialEq, Clone, Copy, Serialize)]
69+
struct Card {
70+
/// 1 is an Ace, 2-10 are the numbered cards, 11 is Jack, 12 is Queen, 13 is King.
71+
rank: u8,
72+
suit: Suit,
73+
}
74+
75+
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Hash)]
76+
enum Suit {
77+
Clubs,
78+
Diamonds,
79+
Hearts,
80+
Spades,
81+
}
82+
83+
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize)]
84+
enum HandCategory {
85+
HighCard,
86+
OnePair,
87+
TwoPair,
88+
ThreeOfAKind,
89+
Straight,
90+
Flush,
91+
FullHouse,
92+
FourOfAKind,
93+
StraightFlush,
94+
}
95+
96+
// With the data domain specified, we can write the logic to generate hands and categorize them.
97+
98+
impl Hand {
99+
/// Generate a random 5 card hand
100+
fn random() -> Self {
101+
let mut rng = rand::thread_rng();
102+
let mut cards = Vec::with_capacity(5);
103+
while cards.len() < 5 {
104+
let rank = rng.gen_range(1..=13);
105+
let suit = match rng.gen_range(0..4) {
106+
0 => Suit::Clubs,
107+
1 => Suit::Diamonds,
108+
2 => Suit::Hearts,
109+
3 => Suit::Spades,
110+
_ => unreachable!(),
111+
};
112+
let card = Card { rank, suit };
113+
if cards.iter().any(|&c| c == card) { continue };
114+
cards.push(card);
115+
}
116+
Self(cards)
117+
}
118+
119+
fn categorize(&self) -> HandCategory {
120+
let rank_counts = self.0.iter().fold(HashMap::new(), |mut acc, card| {
121+
*acc.entry(card.rank).or_insert(0) += 1;
122+
acc
123+
});
124+
let suit_counts = self.0.iter().fold(HashMap::new(), |mut acc, card| {
125+
*acc.entry(card.suit).or_insert(0) += 1;
126+
acc
127+
});
128+
let is_flush = suit_counts.len() == 1;
129+
let is_straight = if self.0.iter().any(|card| card.rank == 1) {
130+
// Handle aces
131+
let min_rank = self.0.iter().map(|card| card.rank).filter(|&rank| rank != 1).min().unwrap();
132+
let max_rank = self.0.iter().map(|card| card.rank).filter(|&rank| rank != 1).max().unwrap();
133+
(min_rank == 2 && max_rank == 5) || (min_rank == 10 && max_rank == 13)
134+
} else {
135+
let min_rank = self.0.iter().map(|card| card.rank).min().unwrap();
136+
let max_rank = self.0.iter().map(|card| card.rank).max().unwrap();
137+
(max_rank - min_rank) as usize == self.0.len() - 1
138+
};
139+
140+
if is_flush && is_straight {
141+
HandCategory::StraightFlush
142+
} else if rank_counts.values().any(|&count| count == 4) {
143+
HandCategory::FourOfAKind
144+
} else if rank_counts.values().any(|&count| count == 3)
145+
&& rank_counts.values().any(|&count| count == 2)
146+
{
147+
HandCategory::FullHouse
148+
} else if is_flush {
149+
HandCategory::Flush
150+
} else if is_straight {
151+
HandCategory::Straight
152+
} else if rank_counts.values().any(|&count| count == 3) {
153+
HandCategory::ThreeOfAKind
154+
} else if rank_counts.values().filter(|&&count| count == 2).count() == 2 {
155+
HandCategory::TwoPair
156+
} else if rank_counts.values().any(|&count| count == 2) {
157+
HandCategory::OnePair
158+
} else {
159+
HandCategory::HighCard
160+
}
161+
}
162+
}

0 commit comments

Comments
 (0)