Skip to content

Commit d06e67e

Browse files
authored
feat(ads-client): invalidate cache after click or impression record (#7088)
1 parent 930e6a7 commit d06e67e

File tree

10 files changed

+457
-161
lines changed

10 files changed

+457
-161
lines changed

components/ads-client/src/client.rs

Lines changed: 85 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
use std::collections::HashMap;
77
use std::time::Duration;
88

9-
use crate::client::ad_response::{AdImage, AdResponse, AdSpoc, AdTile};
9+
use crate::client::ad_response::{
10+
pop_request_hash_from_url, AdImage, AdResponse, AdResponseValue, AdSpoc, AdTile,
11+
};
1012
use crate::client::config::AdsClientConfig;
1113
use crate::error::{RecordClickError, RecordImpressionError, ReportAdError, RequestAdsError};
1214
use crate::http_cache::{HttpCache, RequestCachePolicy};
1315
use crate::mars::MARSClient;
1416
use ad_request::{AdPlacementRequest, AdRequest};
1517
use context_id::{ContextIDComponent, DefaultContextIdCallback};
16-
use serde::de::DeserializeOwned;
1718
use url::Url;
1819
use uuid::Uuid;
1920

@@ -75,18 +76,33 @@ impl AdsClient {
7576
}
7677
}
7778

79+
#[cfg(test)]
80+
pub fn new_with_mars_client(client: MARSClient) -> Self {
81+
let context_id_component = ContextIDComponent::new(
82+
&uuid::Uuid::new_v4().to_string(),
83+
0,
84+
false,
85+
Box::new(DefaultContextIdCallback),
86+
);
87+
Self {
88+
context_id_component,
89+
client,
90+
}
91+
}
92+
7893
fn request_ads<T>(
7994
&self,
8095
ad_placement_requests: Vec<AdPlacementRequest>,
8196
options: Option<RequestCachePolicy>,
8297
) -> Result<AdResponse<T>, RequestAdsError>
8398
where
84-
T: DeserializeOwned,
99+
T: AdResponseValue,
85100
{
86101
let context_id = self.get_context_id()?;
87102
let ad_request = AdRequest::build(context_id, ad_placement_requests)?;
88103
let cache_policy = options.unwrap_or_default();
89-
let response = self.client.fetch_ads(&ad_request, &cache_policy)?;
104+
let (mut response, request_hash) = self.client.fetch_ads(&ad_request, &cache_policy)?;
105+
response.add_request_hash_to_callbacks(&request_hash);
90106
Ok(response)
91107
}
92108

@@ -118,10 +134,18 @@ impl AdsClient {
118134
}
119135

120136
pub fn record_impression(&self, impression_url: Url) -> Result<(), RecordImpressionError> {
137+
let mut impression_url = impression_url.clone();
138+
if let Some(request_hash) = pop_request_hash_from_url(&mut impression_url) {
139+
let _ = self.client.invalidate_cache_by_hash(&request_hash);
140+
}
121141
self.client.record_impression(impression_url)
122142
}
123143

124144
pub fn record_click(&self, click_url: Url) -> Result<(), RecordClickError> {
145+
let mut click_url = click_url.clone();
146+
if let Some(request_hash) = pop_request_hash_from_url(&mut click_url) {
147+
let _ = self.client.invalidate_cache_by_hash(&request_hash);
148+
}
125149
self.client.record_click(click_url)
126150
}
127151

@@ -147,9 +171,12 @@ impl AdsClient {
147171

148172
#[cfg(test)]
149173
mod tests {
150-
use crate::test_utils::{
151-
get_example_happy_image_response, get_example_happy_spoc_response,
152-
get_example_happy_uatile_response, make_happy_placement_requests,
174+
use crate::{
175+
client::config::Environment,
176+
test_utils::{
177+
get_example_happy_image_response, get_example_happy_spoc_response,
178+
get_example_happy_uatile_response, make_happy_placement_requests,
179+
},
153180
};
154181

155182
use super::*;
@@ -173,8 +200,6 @@ mod tests {
173200

174201
#[test]
175202
fn test_request_image_ads_happy() {
176-
use crate::test_utils::create_test_client;
177-
use context_id::{ContextIDComponent, DefaultContextIdCallback};
178203
viaduct_dev::init_backend_dev();
179204

180205
let expected_response = get_example_happy_image_response();
@@ -184,29 +209,18 @@ mod tests {
184209
.with_body(serde_json::to_string(&expected_response).unwrap())
185210
.create();
186211

187-
let mars_client = create_test_client(mockito::server_url());
188-
let context_id_component = ContextIDComponent::new(
189-
&uuid::Uuid::new_v4().to_string(),
190-
0,
191-
false,
192-
Box::new(DefaultContextIdCallback),
193-
);
194-
let component = AdsClient {
195-
context_id_component,
196-
client: mars_client,
197-
};
212+
let mars_client = MARSClient::new(Environment::Test, None);
213+
let ads_client = AdsClient::new_with_mars_client(mars_client);
198214

199215
let ad_placement_requests = make_happy_placement_requests();
200216

201-
let result = component.request_image_ads(ad_placement_requests, None);
217+
let result = ads_client.request_image_ads(ad_placement_requests, None);
202218

203219
assert!(result.is_ok());
204220
}
205221

206222
#[test]
207223
fn test_request_spocs_happy() {
208-
use crate::test_utils::create_test_client;
209-
use context_id::{ContextIDComponent, DefaultContextIdCallback};
210224
viaduct_dev::init_backend_dev();
211225

212226
let expected_response = get_example_happy_spoc_response();
@@ -216,29 +230,18 @@ mod tests {
216230
.with_body(serde_json::to_string(&expected_response).unwrap())
217231
.create();
218232

219-
let mars_client = create_test_client(mockito::server_url());
220-
let context_id_component = ContextIDComponent::new(
221-
&uuid::Uuid::new_v4().to_string(),
222-
0,
223-
false,
224-
Box::new(DefaultContextIdCallback),
225-
);
226-
let component = AdsClient {
227-
context_id_component,
228-
client: mars_client,
229-
};
233+
let mars_client = MARSClient::new(Environment::Test, None);
234+
let ads_client = AdsClient::new_with_mars_client(mars_client);
230235

231236
let ad_placement_requests = make_happy_placement_requests();
232237

233-
let result = component.request_spoc_ads(ad_placement_requests, None);
238+
let result = ads_client.request_spoc_ads(ad_placement_requests, None);
234239

235240
assert!(result.is_ok());
236241
}
237242

238243
#[test]
239244
fn test_request_tiles_happy() {
240-
use crate::test_utils::create_test_client;
241-
use context_id::{ContextIDComponent, DefaultContextIdCallback};
242245
viaduct_dev::init_backend_dev();
243246

244247
let expected_response = get_example_happy_uatile_response();
@@ -248,22 +251,55 @@ mod tests {
248251
.with_body(serde_json::to_string(&expected_response).unwrap())
249252
.create();
250253

251-
let mars_client = create_test_client(mockito::server_url());
252-
let context_id_component = ContextIDComponent::new(
253-
&uuid::Uuid::new_v4().to_string(),
254-
0,
255-
false,
256-
Box::new(DefaultContextIdCallback),
257-
);
258-
let component = AdsClient {
259-
context_id_component,
260-
client: mars_client,
261-
};
254+
let mars_client = MARSClient::new(Environment::Test, None);
255+
let ads_client = AdsClient::new_with_mars_client(mars_client);
262256

263257
let ad_placement_requests = make_happy_placement_requests();
264258

265-
let result = component.request_tile_ads(ad_placement_requests, None);
259+
let result = ads_client.request_tile_ads(ad_placement_requests, None);
266260

267261
assert!(result.is_ok());
268262
}
263+
264+
#[test]
265+
fn test_record_click_invalidates_cache() {
266+
viaduct_dev::init_backend_dev();
267+
let cache = HttpCache::builder("test_record_click_invalidates_cache")
268+
.build()
269+
.unwrap();
270+
let mars_client = MARSClient::new(Environment::Test, Some(cache));
271+
let ads_client = AdsClient::new_with_mars_client(mars_client);
272+
273+
let response = get_example_happy_image_response();
274+
275+
let _m1 = mockito::mock("POST", "/ads")
276+
.with_status(200)
277+
.with_header("content-type", "application/json")
278+
.with_body(serde_json::to_string(&response).unwrap())
279+
.expect(2) // we expect 2 requests to the server, one for the initial ad request and one after for the cache invalidation request
280+
.create();
281+
282+
let response = ads_client
283+
.request_image_ads(make_happy_placement_requests(), None)
284+
.unwrap();
285+
let callback_url = response.values().next().unwrap().callbacks.click.clone();
286+
287+
let _m2 = mockito::mock("GET", callback_url.path())
288+
.with_status(200)
289+
.create();
290+
291+
// Doing another request should hit the cache
292+
ads_client
293+
.request_image_ads(make_happy_placement_requests(), None)
294+
.unwrap();
295+
296+
ads_client.record_click(callback_url).unwrap();
297+
298+
ads_client
299+
.request_ads::<AdImage>(
300+
make_happy_placement_requests(),
301+
Some(RequestCachePolicy::default()),
302+
)
303+
.unwrap();
304+
}
269305
}

components/ads-client/src/client/ad_response.rs

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,37 @@
33
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
44
*/
55

6+
use crate::http_cache::RequestHash;
67
use serde::de::DeserializeOwned;
78
use serde::Deserializer;
89
use serde::{Deserialize, Serialize};
910
use std::collections::HashMap;
1011
use url::Url;
1112

1213
#[derive(Debug, Deserialize, PartialEq, Serialize)]
13-
pub struct AdResponse<T: DeserializeOwned> {
14+
pub struct AdResponse<T: AdResponseValue> {
1415
#[serde(deserialize_with = "deserialize_ad_response", flatten)]
1516
pub data: HashMap<String, Vec<T>>,
1617
}
1718

18-
impl<T: DeserializeOwned> AdResponse<T> {
19+
impl<T: AdResponseValue> AdResponse<T> {
20+
pub fn add_request_hash_to_callbacks(&mut self, request_hash: &RequestHash) {
21+
for ads in self.data.values_mut() {
22+
for ad in ads.iter_mut() {
23+
let callbacks = ad.callbacks_mut();
24+
let hash_str = request_hash.to_string();
25+
callbacks
26+
.click
27+
.query_pairs_mut()
28+
.append_pair("request_hash", &hash_str);
29+
callbacks
30+
.impression
31+
.query_pairs_mut()
32+
.append_pair("request_hash", &hash_str);
33+
}
34+
}
35+
}
36+
1937
pub fn take_first(self) -> HashMap<String, T> {
2038
self.data
2139
.into_iter()
@@ -30,10 +48,31 @@ impl<T: DeserializeOwned> AdResponse<T> {
3048
}
3149
}
3250

51+
pub fn pop_request_hash_from_url(url: &mut Url) -> Option<RequestHash> {
52+
let mut request_hash = None;
53+
let mut query = url::form_urlencoded::Serializer::new(String::new());
54+
55+
for (key, value) in url.query_pairs() {
56+
if key == "request_hash" {
57+
request_hash = Some(RequestHash::from(value.as_ref()));
58+
} else {
59+
query.append_pair(&key, &value);
60+
}
61+
}
62+
63+
let query_string = query.finish();
64+
if query_string.is_empty() {
65+
url.set_query(None);
66+
} else {
67+
url.set_query(Some(&query_string));
68+
}
69+
request_hash
70+
}
71+
3372
fn deserialize_ad_response<'de, D, T>(deserializer: D) -> Result<HashMap<String, Vec<T>>, D::Error>
3473
where
3574
D: Deserializer<'de>,
36-
T: DeserializeOwned,
75+
T: AdResponseValue,
3776
{
3877
let raw = HashMap::<String, serde_json::Value>::deserialize(deserializer)?;
3978
let mut result = HashMap::new();
@@ -117,6 +156,28 @@ pub struct AdCallbacks {
117156
pub report: Option<Url>,
118157
}
119158

159+
pub trait AdResponseValue: DeserializeOwned {
160+
fn callbacks_mut(&mut self) -> &mut AdCallbacks;
161+
}
162+
163+
impl AdResponseValue for AdImage {
164+
fn callbacks_mut(&mut self) -> &mut AdCallbacks {
165+
&mut self.callbacks
166+
}
167+
}
168+
169+
impl AdResponseValue for AdSpoc {
170+
fn callbacks_mut(&mut self) -> &mut AdCallbacks {
171+
&mut self.callbacks
172+
}
173+
}
174+
175+
impl AdResponseValue for AdTile {
176+
fn callbacks_mut(&mut self) -> &mut AdCallbacks {
177+
&mut self.callbacks
178+
}
179+
}
180+
120181
#[cfg(test)]
121182
mod tests {
122183
use super::*;
@@ -372,4 +433,60 @@ mod tests {
372433
assert_eq!(second_ad.alt_text, Some("Third ad".to_string()));
373434
assert_eq!(second_ad.block_key, "key3");
374435
}
436+
437+
#[test]
438+
fn test_add_request_hash_to_callbacks() {
439+
let mut response = AdResponse {
440+
data: HashMap::from([(
441+
"placement_1".to_string(),
442+
vec![AdImage {
443+
alt_text: Some("An ad for a puppy".to_string()),
444+
block_key: "abc123".into(),
445+
callbacks: AdCallbacks {
446+
click: Url::parse("https://example.com/click").unwrap(),
447+
impression: Url::parse("https://example.com/impression").unwrap(),
448+
report: Some(Url::parse("https://example.com/report").unwrap()),
449+
},
450+
format: "billboard".to_string(),
451+
image_url: "https://example.com/image.png".to_string(),
452+
url: "https://example.com/ad".to_string(),
453+
}],
454+
)]),
455+
};
456+
457+
let request_hash = RequestHash::from("abc123def456");
458+
response.add_request_hash_to_callbacks(&request_hash);
459+
let callbacks = &response.data.values().next().unwrap()[0].callbacks;
460+
461+
assert!(callbacks
462+
.click
463+
.query()
464+
.unwrap_or("")
465+
.contains("request_hash=abc123def456"));
466+
assert!(callbacks
467+
.impression
468+
.query()
469+
.unwrap_or("")
470+
.contains("request_hash=abc123def456"));
471+
}
472+
473+
#[test]
474+
fn test_pop_request_hash_from_url() {
475+
let mut url_with_hash =
476+
Url::parse("https://example.com/callback?request_hash=abc123def456&other=param")
477+
.unwrap();
478+
let extracted = pop_request_hash_from_url(&mut url_with_hash);
479+
assert_eq!(extracted, Some(RequestHash::from("abc123def456")));
480+
assert_eq!(url_with_hash.query(), Some("other=param"));
481+
482+
let mut url_without_hash = Url::parse("https://example.com/callback?other=param").unwrap();
483+
let extracted_none = pop_request_hash_from_url(&mut url_without_hash);
484+
assert_eq!(extracted_none, None);
485+
assert_eq!(url_without_hash.query(), Some("other=param"));
486+
487+
let mut url_no_query = Url::parse("https://example.com/callback").unwrap();
488+
let extracted_empty = pop_request_hash_from_url(&mut url_no_query);
489+
assert_eq!(extracted_empty, None);
490+
assert_eq!(url_no_query.query(), None);
491+
}
375492
}

0 commit comments

Comments
 (0)